
Commit fc183b5

fix merge conflict
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 2fdf274 commit fc183b5

File tree

1 file changed: +92 -62 lines changed


vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 92 additions & 62 deletions
@@ -626,8 +626,9 @@ def apply(
         raise NotImplementedError


-def _slice_scales(scales: Optional[torch.Tensor], start: int,
-                  end: int) -> Optional[torch.Tensor]:
+def _slice_scales(
+    scales: Optional[torch.Tensor], start: int, end: int
+) -> Optional[torch.Tensor]:
     if scales is not None:
         if scales.numel() == 1:
             return scales
@@ -640,8 +641,9 @@ class SharedResizableBuffer:
     def __init__(self):
         self.buffer = None

-    def get(self, shape: tuple[int, ...], device: torch.device,
-            dtype: torch.dtype) -> torch.Tensor:
+    def get(
+        self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
+    ) -> torch.Tensor:
         assert shape != ()
         shape_numel = prod(shape)
         if (
@@ -717,8 +719,11 @@ def _chunk_info(self, M: int) -> tuple[int, int]:
         get num_chunks == 1. Take max(M, 1) to avoid divide by zero.
         If there are no tokens to process, the number of chunks will be zero.
         """
-        CHUNK_SIZE = (max(M, 1) if not self.fused_experts.supports_chunking()
-                      else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE))
+        CHUNK_SIZE = (
+            max(M, 1)
+            if not self.fused_experts.supports_chunking()
+            else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
+        )
         num_chunks = cdiv(M, CHUNK_SIZE)
         # If there are no tokens, then there should be no loop iterations.
         assert M > 0 or num_chunks == 0
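
For reference, the chunking arithmetic touched in this hunk can be sketched standalone as below. This is illustrative only: cdiv is re-implemented locally to match the helper used above, and the 32768 cap is just an example value standing in for envs.VLLM_FUSED_MOE_CHUNK_SIZE.

    # Standalone sketch of _chunk_info's arithmetic (illustrative, not vLLM code).
    def cdiv(a: int, b: int) -> int:
        # Ceiling division, matching the cdiv helper used in the hunk above.
        return -(-a // b)

    def chunk_info(M: int, supports_chunking: bool, chunk_size_cap: int) -> tuple[int, int]:
        # Experts that cannot chunk get a single chunk; max(M, 1) avoids a
        # divide-by-zero when there are no tokens at all.
        chunk_size = max(M, 1) if not supports_chunking else min(M, chunk_size_cap)
        num_chunks = cdiv(M, chunk_size)
        # No tokens -> no chunks, hence no loop iterations downstream.
        assert M > 0 or num_chunks == 0
        return num_chunks, chunk_size

    print(chunk_info(100_000, True, 32_768))   # (4, 32768): four chunks of at most 32768 tokens
    print(chunk_info(100_000, False, 32_768))  # (1, 100000): one chunk when chunking is unsupported
    print(chunk_info(0, False, 32_768))        # (0, 1): zero tokens, zero chunks
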
@@ -755,31 +760,37 @@ def _allocate_buffers(
         workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)

         workspace13_shape, workspace2_shape, fused_out_shape = (
-            self.fused_experts.workspace_shapes(M_chunk, M_full, N, K, top_k,
-                                                global_num_experts,
-                                                local_num_experts,
-                                                expert_tokens_meta))
+            self.fused_experts.workspace_shapes(
+                M_chunk,
+                M_full,
+                N,
+                K,
+                top_k,
+                global_num_experts,
+                local_num_experts,
+                expert_tokens_meta,
+            )
+        )

         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = buffers.workspace13.get(workspace13_shape,
-                                              device=device,
-                                              dtype=workspace_dtype)
-        workspace2 = buffers.workspace2.get(workspace2_shape,
-                                            device=device,
-                                            dtype=workspace_dtype)
+        workspace13 = buffers.workspace13.get(
+            workspace13_shape, device=device, dtype=workspace_dtype
+        )
+        workspace2 = buffers.workspace2.get(
+            workspace2_shape, device=device, dtype=workspace_dtype
+        )

         # Construct the entire output that can then be processed in chunks.
         # Reuse workspace13 for the output in the non-chunked case as long
         # as it is large enough. This will not always be the case for standard
         # format experts and with experts that have empty workspaces.
-        if num_chunks == 1 and prod(workspace13_shape) >= prod(
-                fused_out_shape):
+        if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
             fused_out = _resize_cache(workspace13, fused_out_shape)
         else:
-            fused_out = buffers.fused_out.get(fused_out_shape,
-                                              device=device,
-                                              dtype=out_dtype)
+            fused_out = buffers.fused_out.get(
+                fused_out_shape, device=device, dtype=out_dtype
+            )

         return workspace13, workspace2, fused_out

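A side note on the buffer-reuse branch in this hunk: when a single chunk is used and workspace13 has at least as many elements as the fused output, the output is carved out of workspace13 instead of being allocated separately. A minimal sketch of that decision follows, assuming _resize_cache behaves like the flatten/view below; the helper name pick_fused_out is made up for illustration and is not a vLLM API.

    import torch
    from math import prod

    # Hypothetical helper sketching the fused_out allocation decision above.
    def pick_fused_out(workspace13: torch.Tensor,
                       fused_out_shape: tuple[int, ...],
                       num_chunks: int,
                       out_dtype: torch.dtype) -> torch.Tensor:
        if num_chunks == 1 and workspace13.numel() >= prod(fused_out_shape):
            # Reuse the front of workspace13 as the output buffer
            # (roughly what _resize_cache is assumed to do here).
            return workspace13.flatten()[: prod(fused_out_shape)].view(fused_out_shape)
        # Otherwise a separate output buffer is needed, e.g. because several
        # chunks all write into the same full-size output.
        return torch.empty(fused_out_shape, device=workspace13.device, dtype=out_dtype)
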
@@ -794,8 +805,7 @@ def _slice_output_tensor(
         if num_chunks == 1:
             return fused_out

-        assert fused_out.size(0) % M == 0, (
-            f"fused_out shape {fused_out.shape} vs M {M}")
+        assert fused_out.size(0) % M == 0, f"fused_out shape {fused_out.shape} vs M {M}"
         factor = fused_out.size(0) // M
         out_chunk_size = CHUNK_SIZE * factor
         s = chunk_idx * out_chunk_size
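
The slicing in this hunk relies on the fused output's leading dimension being a whole multiple of M (for example one row per (token, expert) pair), so each chunk of CHUNK_SIZE tokens maps to a contiguous block of rows. A small illustration follows; the end-index clamp is an assumption, since the hunk ends before that line of the function.

    # Illustration of the chunked-output row arithmetic (not vLLM code).
    def output_chunk_slice(total_rows: int, M: int, chunk_idx: int,
                           chunk_size: int) -> tuple[int, int]:
        assert total_rows % M == 0, f"total_rows {total_rows} vs M {M}"
        factor = total_rows // M                 # e.g. rows per token
        out_chunk_size = chunk_size * factor
        s = chunk_idx * out_chunk_size
        e = min(s + out_chunk_size, total_rows)  # assumed clamp for the last, partial chunk
        return s, e

    print(output_chunk_slice(20, 10, 0, 4))  # (0, 8)
    print(output_chunk_slice(20, 10, 2, 4))  # (16, 20): last chunk holds the remaining 2 tokens
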
@@ -816,23 +826,24 @@ def _slice_expert_tokens_metadata(
         # The existing expert_num_tokens is for the entire a1q
         # input. Chunking forces recomputation of the number
         # of tokens assigned to each expert.
-        c_expert_num_tokens = count_expert_num_tokens(chunk_topk_ids,
-                                                      local_num_experts,
-                                                      expert_map)
+        c_expert_num_tokens = count_expert_num_tokens(
+            chunk_topk_ids, local_num_experts, expert_map
+        )

         c_expert_num_tokens_cpu = None
         need_expert_num_tokens_cpu = (
-            full_expert_tokens_meta.expert_num_tokens_cpu is not None)
+            full_expert_tokens_meta.expert_num_tokens_cpu is not None
+        )
         if need_expert_num_tokens_cpu:
             # This is blocking as some implementations need the count
             # on the CPU to determine appropriate input/out fused-moe
             # buffers
-            c_expert_num_tokens_cpu = c_expert_num_tokens.to(
-                "cpu", non_blocking=False)
+            c_expert_num_tokens_cpu = c_expert_num_tokens.to("cpu", non_blocking=False)

         return ExpertTokensMetadata(
             expert_num_tokens=c_expert_num_tokens,
-            expert_num_tokens_cpu=c_expert_num_tokens_cpu)
+            expert_num_tokens_cpu=c_expert_num_tokens_cpu,
+        )

     def _prepare(
         self,
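
For context on the recomputation in this hunk: the per-chunk counts are essentially a histogram of the chunk's topk_ids over the local experts, with expert_map (when present) remapping global expert ids to local ones. The sketch below is a rough eager-mode stand-in, not the actual count_expert_num_tokens kernel, and the -1 marker for experts owned by other ranks is an assumption.

    import torch
    from typing import Optional

    # Rough stand-in for count_expert_num_tokens (illustrative only).
    def count_tokens_per_expert(chunk_topk_ids: torch.Tensor,
                                local_num_experts: int,
                                expert_map: Optional[torch.Tensor]) -> torch.Tensor:
        ids = chunk_topk_ids.flatten()
        if expert_map is not None:
            ids = expert_map[ids]   # map global expert ids to local ids
            ids = ids[ids >= 0]     # drop entries mapped to other ranks (assumed -1 marker)
        return torch.bincount(ids, minlength=local_num_experts)

    # Example: 3 tokens, top-2 routing, 4 local experts, no expert_map.
    topk_ids = torch.tensor([[0, 2], [2, 3], [1, 2]])
    print(count_tokens_per_expert(topk_ids, 4, None))  # tensor([1, 1, 3, 1])
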
@@ -843,11 +854,11 @@ def _prepare(
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
     ) -> tuple[
-            torch.Tensor,
-            Optional[torch.Tensor],
-            Optional[ExpertTokensMetadata],
-            torch.Tensor,
-            torch.Tensor,
+        torch.Tensor,
+        Optional[torch.Tensor],
+        Optional[ExpertTokensMetadata],
+        torch.Tensor,
+        torch.Tensor,
     ]:
         """
         The _prepare method is a wrapper around self.prepare_finalize.prepare
@@ -859,16 +870,21 @@ def _prepare(
             # TODO(lucas): enable in follow-up
             assert not dbo_enabled()

-            (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
-             _expert_topk_weights) = self.prepare_finalize.prepare(
-                 hidden_states,
-                 topk_weights,
-                 topk_ids,
-                 global_num_experts,
-                 expert_map,
-                 apply_router_weight_on_input,
-                 self.fused_experts.quant_config,
-             )
+            (
+                a1q,
+                a1q_scale,
+                expert_tokens_meta,
+                _expert_topk_ids,
+                _expert_topk_weights,
+            ) = self.prepare_finalize.prepare(
+                hidden_states,
+                topk_weights,
+                topk_ids,
+                global_num_experts,
+                expert_map,
+                apply_router_weight_on_input,
+                self.fused_experts.quant_config,
+            )
         else:
             # Overlap shared expert compute with all2all dispatch.
             dbo_maybe_run_recv_hook()
@@ -931,7 +947,9 @@ def _fused_experts(
         apply_router_weight_on_input: bool,
         expert_tokens_meta: Optional[ExpertTokensMetadata],
     ) -> torch.Tensor:
-        _, M_full, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
+        _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
+            a1q, w1, w2, topk_ids
+        )

         num_chunks, CHUNK_SIZE = self._chunk_info(M_full)

@@ -959,19 +977,32 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
         else:
             assert num_chunks > 0
             workspace13, workspace2, fused_out = self._allocate_buffers(
-                in_dtype, a1q.device, CHUNK_SIZE, M_full, N, K, top_k,
-                global_num_experts, local_num_experts, expert_tokens_meta)
+                in_dtype,
+                a1q.device,
+                CHUNK_SIZE,
+                M_full,
+                N,
+                K,
+                top_k,
+                global_num_experts,
+                local_num_experts,
+                expert_tokens_meta,
+            )

             for chunk_idx in range(num_chunks):
                 s, e = input_chunk_range(chunk_idx)

                 c_expert_tokens_meta = self._slice_expert_tokens_metadata(
-                    num_chunks, expert_tokens_meta, topk_ids[s:e],
-                    local_num_experts, expert_map)
+                    num_chunks,
+                    expert_tokens_meta,
+                    topk_ids[s:e],
+                    local_num_experts,
+                    expert_map,
+                )

-                c_fused_out = self._slice_output_tensor(fused_out, chunk_idx,
-                                                        num_chunks, CHUNK_SIZE,
-                                                        M_full)
+                c_fused_out = self._slice_output_tensor(
+                    fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full
+                )

                 self.fused_experts.apply(
                     output=c_fused_out,
@@ -1111,15 +1142,14 @@ def forward(
         if global_num_experts == -1:
             global_num_experts = local_num_experts

-        a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights = (
-            self._prepare(
-                hidden_states,
-                topk_weights,
-                topk_ids,
-                global_num_experts,
-                expert_map,
-                apply_router_weight_on_input,
-            ))
+        a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights = self._prepare(
+            hidden_states,
+            topk_weights,
+            topk_ids,
+            global_num_experts,
+            expert_map,
+            apply_router_weight_on_input,
+        )

         fused_out = self._fused_experts(
             in_dtype=hidden_states.dtype,
