 
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
-   NVIDIA GPUs with sm90a (Hopper).
+   NVIDIA GPUs with sm100 (Blackwell).
 */
 
-#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
+#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
 
-void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales,
-                            std::optional<torch::Tensor> const& bias) {
+void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
+                             torch::Tensor const& b,
+                             torch::Tensor const& a_scales,
+                             torch::Tensor const& b_scales,
+                             std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
   int M = a.size(0), N = b.size(1), K = a.size(1);
-
-  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
-    // Standard per-tensor/per-token/per-channel scaling
-    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-    if (a.dtype() == torch::kFloat8_e4m3fn) {
-      vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
-    } else {
-      TORCH_CHECK(a.dtype() == torch::kInt8);
-      vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
-    }
-  } else {
-    using GroupShape = std::array<int64_t, 2>;
-    auto make_group_shape = [](torch::Tensor const& x,
-                               torch::Tensor const& s) -> GroupShape {
-      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
-      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
-              cuda_utils::ceil_div(x.size(1), s.size(1))};
-    };
-
-    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
-    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
-
-    // 1x128 per-token group scales for activations
-    // 128x128 blockwise scales for weights
-    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
-                 b_scale_group_shape == GroupShape{128, 128} &&
-                 a.dtype() == torch::kFloat8_e4m3fn &&
-                 b.dtype() == torch::kFloat8_e4m3fn),
-                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
-                "a_scale_group_shape must be [1, 128]. Got: [",
-                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
-                "]\n"
-                "b_scale_group_shape must be [128, 128]. Got: [",
-                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
-    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
-
-    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
-  }
-}
-
-void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
-                                torch::Tensor const& b,
-                                torch::Tensor const& a_scales,
-                                torch::Tensor const& b_scales,
-                                torch::Tensor const& azp_adj,
-                                std::optional<torch::Tensor> const& azp,
-                                std::optional<torch::Tensor> const& bias) {
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-  vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
-                                        azp, bias);
+  TORCH_CHECK(
+      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
+      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
+
+  // Standard per-tensor/per-token/per-channel scaling
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
+              "Currently, only fp8 gemm is implemented for Blackwell");
+  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
 }
 
 #endif
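
For context (not part of the diff): a minimal call-site sketch showing the scale layouts the new sm100 entry point accepts. Everything below is illustrative, assuming this translation unit is built into the usual torch extension so the function is visible to the caller; the shapes, the bf16 output dtype, and the tensor layouts are assumptions, not requirements stated by this diff. The added checks only admit per-tensor scales (numel() == 1), per-token a_scales (numel() == M), and per-channel b_scales (numel() == N); any other scale layout, or any non-fp8 input dtype, now fails with the "not implemented for Blackwell" messages.

// Hypothetical, self-contained call-site sketch; names and shapes are
// illustrative and not taken from the diff above.
#include <cstdint>
#include <optional>

#include <torch/torch.h>

// Entry point added in this diff (defined in the .cu file above).
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias);

void example_sm100_call() {
  const int64_t M = 16, N = 4096, K = 4096;
  auto cuda = torch::Device(torch::kCUDA);

  // fp8 e4m3 inputs -- the only input dtype the sm100 path accepts.
  auto fp8 = torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(cuda);
  auto a = torch::empty({M, K}, fp8);  // activations
  auto b = torch::empty({K, N}, fp8);  // weights

  // Per-token a_scales (numel == M) and per-channel b_scales (numel == N).
  // Per-tensor scales (numel == 1) would also pass the new check; anything
  // else (e.g. 128x128 blockwise) is rejected on Blackwell for now.
  auto f32 = torch::TensorOptions().dtype(torch::kFloat32).device(cuda);
  auto a_scales = torch::ones({M, 1}, f32);
  auto b_scales = torch::ones({1, N}, f32);

  // Output dtype is the caller's choice; bf16 is just an example here.
  auto c = torch::empty(
      {M, N}, torch::TensorOptions().dtype(torch::kBFloat16).device(cuda));

  cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, /*bias=*/std::nullopt);
}

Compared with the sm90 entry point removed above, the sm100 path in this diff covers only fp8 with per-tensor/per-token/per-channel scales; the int8, blockwise-scaled, and azp (asymmetric zero-point) variants are not ported here.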