
Commit 9252dc5

fix
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent: ae7e648


2 files changed: +77, -77 lines changed


csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu

Lines changed: 17 additions & 60 deletions
@@ -5,73 +5,30 @@
 
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
-   NVIDIA GPUs with sm90a (Hopper).
+   NVIDIA GPUs with sm100 (Blackwell).
 */
 
-#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
+#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
 
-void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales,
-                            std::optional<torch::Tensor> const& bias) {
+void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
+                             torch::Tensor const& b,
+                             torch::Tensor const& a_scales,
+                             torch::Tensor const& b_scales,
+                             std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
   int M = a.size(0), N = b.size(1), K = a.size(1);
-
-  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
-    // Standard per-tensor/per-token/per-channel scaling
-    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-    if (a.dtype() == torch::kFloat8_e4m3fn) {
-      vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
-    } else {
-      TORCH_CHECK(a.dtype() == torch::kInt8);
-      vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
-    }
-  } else {
-    using GroupShape = std::array<int64_t, 2>;
-    auto make_group_shape = [](torch::Tensor const& x,
-                               torch::Tensor const& s) -> GroupShape {
-      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
-      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
-              cuda_utils::ceil_div(x.size(1), s.size(1))};
-    };
-
-    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
-    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
-
-    // 1x128 per-token group scales for activations
-    // 128x128 blockwise scales for weights
-    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
-                 b_scale_group_shape == GroupShape{128, 128} &&
-                 a.dtype() == torch::kFloat8_e4m3fn &&
-                 b.dtype() == torch::kFloat8_e4m3fn),
-                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
-                "a_scale_group_shape must be [1, 128]. Got: [",
-                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
-                "]\n"
-                "b_scale_group_shape must be [128, 128]. Got: [",
-                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
-    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
-
-    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
-  }
-}
-
-void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
-                                torch::Tensor const& b,
-                                torch::Tensor const& a_scales,
-                                torch::Tensor const& b_scales,
-                                torch::Tensor const& azp_adj,
-                                std::optional<torch::Tensor> const& azp,
-                                std::optional<torch::Tensor> const& bias) {
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-  vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
-                                        azp, bias);
+  TORCH_CHECK(
+      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
+      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
+
+  // Standard per-tensor/per-token/per-channel scaling
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
+              "Currently, only fp8 gemm is implemented for Blackwell");
+  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
 }
 
 #endif
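
Note (not part of the commit): after this change the sm100 entry point only accepts per-tensor, per-token, or per-channel scales. The standalone libtorch sketch below is an illustrative assumption of caller-side scale tensors that satisfy the new numel() and contiguity checks for an [M, K] x [K, N] problem; the helper names are hypothetical.

// Hypothetical caller-side sketch of scale tensors accepted by the new
// cutlass_scaled_mm_sm100 path; shapes are illustrative, not from this commit.
#include <torch/torch.h>

// Per-tensor scale: numel() == 1, valid for both a_scales and b_scales.
torch::Tensor make_per_tensor_scale() {
  return torch::ones({1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
}

// Per-token activation scales: numel() == a.size(0) == M.
torch::Tensor make_per_token_a_scales(int64_t M) {
  return torch::ones({M, 1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
}

// Per-channel weight scales: numel() == b.size(1) == N.
torch::Tensor make_per_channel_b_scales(int64_t N) {
  return torch::ones({1, N}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
}

Block-scaled fp8 (the [1, 128] / [128, 128] group shapes handled below for sm90) would trip the new TORCH_CHECK here and remains a Hopper-only path in this commit.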

csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu

Lines changed: 60 additions & 17 deletions
@@ -5,30 +5,73 @@
 
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
-   NVIDIA GPUs with sm100 (Blackwell).
+   NVIDIA GPUs with sm90a (Hopper).
 */
 
-#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
+#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
 
-void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
-                             torch::Tensor const& b,
-                             torch::Tensor const& a_scales,
-                             torch::Tensor const& b_scales,
-                             std::optional<torch::Tensor> const& bias) {
+void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales,
+                            std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
   int M = a.size(0), N = b.size(1), K = a.size(1);
-  TORCH_CHECK(
-      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
-      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
-
-  // Standard per-tensor/per-token/per-channel scaling
-  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
-              "Currently, only fp8 gemm is implemented for Blackwell");
-  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
+
+  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
+    // Standard per-tensor/per-token/per-channel scaling
+    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+    if (a.dtype() == torch::kFloat8_e4m3fn) {
+      vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
+    } else {
+      TORCH_CHECK(a.dtype() == torch::kInt8);
+      vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
+    }
+  } else {
+    using GroupShape = std::array<int64_t, 2>;
+    auto make_group_shape = [](torch::Tensor const& x,
+                               torch::Tensor const& s) -> GroupShape {
+      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
+      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
+              cuda_utils::ceil_div(x.size(1), s.size(1))};
+    };
+
+    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
+    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
+
+    // 1x128 per-token group scales for activations
+    // 128x128 blockwise scales for weights
+    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
+                 b_scale_group_shape == GroupShape{128, 128} &&
+                 a.dtype() == torch::kFloat8_e4m3fn &&
+                 b.dtype() == torch::kFloat8_e4m3fn),
+                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
+                "a_scale_group_shape must be [1, 128]. Got: [",
+                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
+                "]\n"
+                "b_scale_group_shape must be [128, 128]. Got: [",
+                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
+    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
+
+    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
+  }
+}
+
+void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
+                                torch::Tensor const& b,
+                                torch::Tensor const& a_scales,
+                                torch::Tensor const& b_scales,
+                                torch::Tensor const& azp_adj,
+                                std::optional<torch::Tensor> const& azp,
+                                std::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
+                                        azp, bias);
 }
 
 #endif
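
Note (not part of the commit): the restored sm90 blockwise branch derives a group shape by dividing each tensor dimension by the corresponding scale-tensor dimension via cuda_utils::ceil_div. Below is a minimal standalone sketch of that computation; the local ceil_div helper and the example sizes M, K, N are assumptions for illustration, since the real code operates on torch::Tensor sizes.

// Standalone sketch of the group-shape computation used by the sm90 blockwise
// path; a local ceil_div stands in for cuda_utils::ceil_div so it compiles alone.
#include <array>
#include <cassert>
#include <cstdint>

using GroupShape = std::array<int64_t, 2>;

static int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Group shape = how many tensor elements share one scale along each dimension.
static GroupShape make_group_shape(int64_t rows, int64_t cols,
                                   int64_t scale_rows, int64_t scale_cols) {
  return {ceil_div(rows, scale_rows), ceil_div(cols, scale_cols)};
}

int main() {
  // Example sizes (assumed): activations a of shape [M, K], weights b of [K, N].
  int64_t M = 256, K = 4096, N = 7168;

  // Per-token group scales for a: [M, K / 128] -> group shape {1, 128}.
  assert((make_group_shape(M, K, M, K / 128) == GroupShape{1, 128}));

  // Blockwise scales for b: [K / 128, N / 128] -> group shape {128, 128}.
  assert((make_group_shape(K, N, K / 128, N / 128) == GroupShape{128, 128}));
  return 0;
}

Only this {1, 128} / {128, 128} combination with float8_e4m3fn inputs passes the TORCH_CHECK in the blockwise branch above.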
