Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Enable 8-bit weights in Fused Marlin MoE #8032

Merged
merged 30 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0abac6f
Enable 8-bit weights in Fused Marlin MoE
ElizaWszola Aug 30, 2024
fdf69c2
fix rocm
ElizaWszola Aug 30, 2024
4da163b
bad paste
ElizaWszola Aug 30, 2024
21d2337
add test case; fix imports for tests
dsikka Aug 30, 2024
080ab23
Merge branch 'main' into marlin-moe-8-bit
dsikka Aug 30, 2024
638777a
fix to adapt custom_routin_function
dsikka Aug 30, 2024
bd4b84d
Use select_experts to compute top_k tensors in fused moe
ElizaWszola Sep 2, 2024
bef6b53
bring back fused_moe_marlin -> fused_marlin_moe
ElizaWszola Sep 3, 2024
befc52b
Merge branch 'main' into marlin-moe-8-bit
ElizaWszola Sep 4, 2024
b45594c
remove large model
dsikka Sep 4, 2024
effd2cd
Cleanup, comments
ElizaWszola Sep 4, 2024
52c3353
fix moe init
ElizaWszola Sep 4, 2024
882fd9c
move larger models to an options larger test
dsikka Sep 4, 2024
973d914
add optional flag
dsikka Sep 4, 2024
72bc899
swap gpu
dsikka Sep 5, 2024
eea2bc3
Temp disable part of moe tests to see what's breaking
ElizaWszola Sep 5, 2024
9c29dc2
Fixes to act_order, make unit tests more robust
ElizaWszola Sep 5, 2024
6d04dcd
try to narrow down cuda error
ElizaWszola Sep 5, 2024
83e7999
Try different subset of test params
ElizaWszola Sep 6, 2024
6a42eaf
.
ElizaWszola Sep 6, 2024
3288842
.
ElizaWszola Sep 6, 2024
61ef4ba
Merge branch 'main' into marlin-moe-8-bit
ElizaWszola Sep 10, 2024
667d23e
fix and cleanup after merge
ElizaWszola Sep 10, 2024
b16838e
cleanup
ElizaWszola Sep 10, 2024
e53abb9
validate cache for the kernel code
ElizaWszola Sep 10, 2024
2f82715
cleanup commented out code
ElizaWszola Sep 11, 2024
f97b524
Merge branch 'main' into marlin-moe-8-bit
ElizaWszola Sep 13, 2024
771f693
Merge branch 'main' into marlin-moe-8-bit
ElizaWszola Sep 13, 2024
aac7c20
fix type conversion
ElizaWszola Sep 13, 2024
9d7caad
Merge branch 'main' into marlin-moe-8-bit
ElizaWszola Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
537 changes: 389 additions & 148 deletions csrc/moe/marlin_moe_ops.cu

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions csrc/moe/marlin_moe_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

#include <torch/all.h>

#include "core/scalar_type.hpp"

torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
8 changes: 5 additions & 3 deletions csrc/moe/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor");
"g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif
}
Expand Down
18 changes: 12 additions & 6 deletions tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def compute_max_diff(output, output_ref):
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_fused_marlin_moe(
m: int,
n: int,
Expand All @@ -148,6 +149,7 @@ def test_fused_marlin_moe(
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
):
torch.manual_seed(7)

Expand All @@ -161,13 +163,12 @@ def test_fused_marlin_moe(
if group_size in (k, n):
return

quant_type = scalar_types.uint4b8
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

w_ref1_l = []
qweight1_l = []
Expand Down Expand Up @@ -240,6 +241,7 @@ def test_fused_marlin_moe(
topk_ids,
w1_scale=scales1,
w2_scale=scales2,
num_bits=num_bits,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2
Expand All @@ -254,14 +256,16 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm(
@pytest.mark.parametrize("num_bits", [4, 8])
def test_single_marlin_moe_multiply(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
):
if topk > e:
return
Expand All @@ -273,7 +277,8 @@ def test_marlin_moe_mmm(
if group_size == k:
return

quant_type = scalar_types.uint4b8
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
Expand Down Expand Up @@ -308,7 +313,8 @@ def test_marlin_moe_mmm(
g_idx,
sort_indices,
topk,
renormalize=False)
renormalize=False,
num_bits=num_bits)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert compute_max_diff(marlin_output, torch_output) < 1e-2
3 changes: 2 additions & 1 deletion tests/weight_loading/models-large.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
Empty file modified tests/weight_loading/run_model_weight_loading_test.sh
100644 → 100755
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's going on this with this file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made some accidental changes in it when merging, and when I reverted them, it just showed no lines in diff

I don't know why it's still there...

Empty file.
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
num_bits: int) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
output = torch.empty((num_experts, size_k // 16, size_n * 2),
output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device,
dtype=b_q_weight.dtype)
for e in range(num_experts):
Expand Down
44 changes: 30 additions & 14 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types


def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor:
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
) -> torch.Tensor:
"""
This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism.
Expand All @@ -36,6 +39,7 @@ def single_marlin_moe(
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
Expand All @@ -48,10 +52,11 @@ def single_marlin_moe(
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]

M, K = hidden_states.shape
E = w.shape[0]
N = w.shape[2] // 2
N = w.shape[2] // (num_bits // 2)

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
Expand All @@ -76,10 +81,13 @@ def single_marlin_moe(
device="cuda",
requires_grad=False)

scalar_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)

intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True,
False)
g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
block_size_m, True, False)

return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)

Expand All @@ -98,6 +106,7 @@ def fused_marlin_moe(
override_config: Optional[Dict[str, Any]] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
num_bits: int = 8,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
Expand All @@ -122,6 +131,7 @@ def fused_marlin_moe(
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- num_bits (bool): The number of bits in expert weights quantization.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
Expand All @@ -131,13 +141,14 @@ def fused_marlin_moe(
0], "Number of tokens mismatch"
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[
1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2), "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]

M, K = hidden_states.shape
E = w1.shape[0]
Expand Down Expand Up @@ -165,6 +176,9 @@ def fused_marlin_moe(
device="cuda",
requires_grad=False)

scalar_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)

intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
Expand All @@ -181,6 +195,7 @@ def fused_marlin_moe(
g_idx1,
perm1,
workspace,
scalar_type,
M,
2 * N,
K,
Expand All @@ -204,6 +219,7 @@ def fused_marlin_moe(
g_idx2,
perm2,
workspace,
scalar_type,
M,
K,
N,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat)
from vllm.model_executor.utils import set_weight_attrs
Expand Down Expand Up @@ -38,10 +40,11 @@ def __init__(

if not (self.quant_config.quant_format
== CompressionFormat.pack_quantized.value
and self.num_bits == 4):
and self.num_bits in WNA16_SUPPORTED_BITS):
raise ValueError("For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for 4 bits")
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}")

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
Expand Down Expand Up @@ -292,4 +295,5 @@ def apply(
topk_ids,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
num_bits=self.num_bits,
)
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,4 +611,5 @@ def apply(
topk_ids,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
num_bits=self.quant_config.quant_type.size_bits,
).to(orig_dtype)
8 changes: 1 addition & 7 deletions vllm/model_executor/model_loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,7 @@ def get_model_architecture(
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors"]
# for gptq_marlin, only run fused MoE for int4
if model_config.quantization == "gptq_marlin":
hf_quant_config = getattr(model_config.hf_config,
"quantization_config", None)
if hf_quant_config and hf_quant_config.get("bits") == 4:
mixtral_supported.append("gptq_marlin")
mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"]

if (model_config.quantization is not None
and model_config.quantization not in mixtral_supported
Expand Down
Loading