|
| 1 | +/* |
| 2 | + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include <torch/all.h> |
| 18 | + |
| 19 | +#include <ATen/cuda/CUDAContext.h> |
| 20 | +#include <c10/cuda/CUDAGuard.h> |
| 21 | + |
| 22 | +#include "cutlass_extensions/common.hpp" |
| 23 | + |
| 24 | +#include "cutlass/cutlass.h" |
| 25 | + |
| 26 | +#include "cutlass/gemm/collective/collective_builder.hpp" |
| 27 | +#include "cutlass/epilogue/collective/collective_builder.hpp" |
| 28 | +#include "cutlass/gemm/device/gemm_universal_adapter.h" |
| 29 | +#include "cutlass/gemm/kernel/gemm_universal.hpp" |
| 30 | + |
| 31 | +#include "cutlass/util/packed_stride.hpp" |
| 32 | + |
| 33 | +using namespace cute; |
| 34 | + |
| 35 | +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) |
| 36 | +// Kernel Perf config |
| 37 | +template <typename T> |
| 38 | +struct KernelTraits; |
| 39 | + |
| 40 | +template <> |
| 41 | +struct KernelTraits<float> { |
| 42 | + using MmaTileShape = Shape<_128, _128, _256>; |
| 43 | + using ClusterShape = Shape<_1, _1, _1>; |
| 44 | + using PerSmTileShape_MNK = Shape<_128, _128, _256>; |
| 45 | +}; |
| 46 | + |
| 47 | +template <> |
| 48 | +struct KernelTraits<cutlass::half_t> { |
| 49 | + using MmaTileShape = Shape<_256, _256, _256>; |
| 50 | + using ClusterShape = Shape<_4, _4, _1>; |
| 51 | + using PerSmTileShape_MNK = Shape<_128, _256, _256>; |
| 52 | +}; |
| 53 | + |
| 54 | +template <> |
| 55 | +struct KernelTraits<cutlass::bfloat16_t> { |
| 56 | + using MmaTileShape = Shape<_256, _256, _256>; |
| 57 | + using ClusterShape = Shape<_4, _4, _1>; |
| 58 | + using PerSmTileShape_MNK = Shape<_128, _256, _256>; |
| 59 | +}; |
| 60 | + |
| 61 | +template <typename T> |
| 62 | +struct Fp4GemmSm100 { |
| 63 | + // A matrix configuration |
| 64 | + using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; |
| 65 | + using LayoutATag = cutlass::layout::RowMajor; |
| 66 | + static constexpr int AlignmentA = 32; |
| 67 | + |
| 68 | + // B matrix configuration |
| 69 | + using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; |
| 70 | + using LayoutBTag = cutlass::layout::ColumnMajor; |
| 71 | + static constexpr int AlignmentB = 32; |
| 72 | + |
| 73 | + // C/D matrix configuration |
| 74 | + using ElementD = T; |
| 75 | + using ElementC = T; |
| 76 | + using LayoutCTag = cutlass::layout::RowMajor; |
| 77 | + using LayoutDTag = cutlass::layout::RowMajor; |
| 78 | + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; |
| 79 | + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; |
| 80 | + // Kernel functional config |
| 81 | + using ElementAccumulator = float; |
| 82 | + using ArchTag = cutlass::arch::Sm100; |
| 83 | + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; |
| 84 | + |
| 85 | + // Kernel Perf config |
| 86 | + using MmaTileShape = typename KernelTraits<T>::MmaTileShape; |
| 87 | + using ClusterShape = typename KernelTraits<T>::ClusterShape; |
| 88 | + using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK; |
| 89 | + |
| 90 | + using CollectiveEpilogue = |
| 91 | + typename cutlass::epilogue::collective::CollectiveBuilder< |
| 92 | + ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape, |
| 93 | + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, |
| 94 | + ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD, |
| 95 | + LayoutDTag, AlignmentD, |
| 96 | + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; |
| 97 | + |
| 98 | + using CollectiveMainloop = |
| 99 | + typename cutlass::gemm::collective::CollectiveBuilder< |
| 100 | + ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB, |
| 101 | + LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape, |
| 102 | + ClusterShape, |
| 103 | + cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( |
| 104 | + sizeof(typename CollectiveEpilogue::SharedStorage))>, |
| 105 | + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; |
| 106 | + |
| 107 | + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< |
| 108 | + Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>; |
| 109 | + using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; |
| 110 | + using StrideA = typename Gemm::GemmKernel::StrideA; |
| 111 | + using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{})); |
| 112 | + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; |
| 113 | + using StrideB = typename Gemm::GemmKernel::StrideB; |
| 114 | + using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{})); |
| 115 | + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; |
| 116 | + using StrideC = typename Gemm::GemmKernel::StrideC; |
| 117 | + using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{})); |
| 118 | + using StrideD = typename Gemm::GemmKernel::StrideD; |
| 119 | + using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{})); |
| 120 | +}; |
| 121 | + |
| 122 | +template <typename T> |
| 123 | +typename T::Gemm::Arguments args_from_options( |
| 124 | + at::Tensor& D, at::Tensor const& A, at::Tensor const& B, |
| 125 | + at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha, |
| 126 | + int64_t M, int64_t N, int64_t K) { |
| 127 | + using ElementA = typename T::Gemm::ElementA; |
| 128 | + using ElementB = typename T::Gemm::ElementB; |
| 129 | + using ElementSFA = cutlass::float_ue4m3_t; |
| 130 | + using ElementSFB = cutlass::float_ue4m3_t; |
| 131 | + using ElementD = typename T::Gemm::ElementD; |
| 132 | + using ElementCompute = float; |
| 133 | + using StrideA = typename T::StrideA; |
| 134 | + using StrideB = typename T::StrideB; |
| 135 | + using StrideD = typename T::StrideD; |
| 136 | + using Sm100BlkScaledConfig = |
| 137 | + typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; |
| 138 | + |
| 139 | + int m = static_cast<int>(M); |
| 140 | + int n = static_cast<int>(N); |
| 141 | + int k = static_cast<int>(K); |
| 142 | + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); |
| 143 | + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); |
| 144 | + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); |
| 145 | + |
| 146 | + auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA( |
| 147 | + cute::make_shape(m, n, k, 1)); |
| 148 | + auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB( |
| 149 | + cute::make_shape(m, n, k, 1)); |
| 150 | + |
| 151 | + typename T::Gemm::Arguments arguments{ |
| 152 | + cutlass::gemm::GemmUniversalMode::kGemm, |
| 153 | + {m, n, k, 1}, |
| 154 | + {// Mainloop arguments |
| 155 | + static_cast<ElementA const*>(A.data_ptr()), stride_A, |
| 156 | + static_cast<ElementB const*>(B.data_ptr()), stride_B, |
| 157 | + static_cast<ElementSFA const*>(A_sf.data_ptr()), layout_SFA, |
| 158 | + static_cast<ElementSFB const*>(B_sf.data_ptr()), layout_SFB}, |
| 159 | + { // Epilogue arguments |
| 160 | + {}, // epilogue.thread |
| 161 | + static_cast<ElementD const*>(D.data_ptr()), |
| 162 | + stride_D, |
| 163 | + static_cast<ElementD*>(D.data_ptr()), |
| 164 | + stride_D}}; |
| 165 | + auto& fusion_args = arguments.epilogue.thread; |
| 166 | + fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr()); |
| 167 | + return arguments; |
| 168 | +} |
| 169 | + |
| 170 | +template <typename T> |
| 171 | +void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, |
| 172 | + at::Tensor const& A_sf, at::Tensor const& B_sf, |
| 173 | + at::Tensor const& alpha, int64_t m, int64_t n, int64_t k, |
| 174 | + cudaStream_t stream) { |
| 175 | + typename Fp4GemmSm100<T>::Gemm gemm; |
| 176 | + |
| 177 | + auto arguments = |
| 178 | + args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k); |
| 179 | + |
| 180 | + size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments); |
| 181 | + auto const workspace_options = |
| 182 | + torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); |
| 183 | + auto workspace = torch::empty(workspace_size, workspace_options); |
| 184 | + |
| 185 | + CUTLASS_CHECK(gemm.can_implement(arguments)); |
| 186 | + |
| 187 | + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); |
| 188 | + |
| 189 | + CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); |
| 190 | +} |
| 191 | +#else |
| 192 | +template <typename T> |
| 193 | +void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, |
| 194 | + at::Tensor const& A_sf, at::Tensor const& B_sf, |
| 195 | + at::Tensor const& alpha, int64_t m, int64_t n, int64_t k, |
| 196 | + cudaStream_t stream) { |
| 197 | + TORCH_CHECK(false, "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " |
| 198 | + "a CUTLASS 3.8 source directory to enable support."); |
| 199 | +} |
| 200 | +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) |
| 201 | + |
| 202 | +#define CHECK_TYPE(x, st, m) \ |
| 203 | + TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m) |
| 204 | +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") |
| 205 | +#define CHECK_CONTIGUOUS(x, m) \ |
| 206 | + TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") |
| 207 | +#define CHECK_INPUT(x, st, m) \ |
| 208 | + CHECK_TH_CUDA(x, m); \ |
| 209 | + CHECK_CONTIGUOUS(x, m); \ |
| 210 | + CHECK_TYPE(x, st, m) |
| 211 | + |
| 212 | +constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; |
| 213 | +constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; |
| 214 | + |
| 215 | +void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, |
| 216 | + torch::Tensor const& B, |
| 217 | + torch::Tensor const& A_sf, |
| 218 | + torch::Tensor const& B_sf, |
| 219 | + torch::Tensor const& alpha) { |
| 220 | + CHECK_INPUT(A, FLOAT4_E2M1X2, "a"); |
| 221 | + CHECK_INPUT(B, FLOAT4_E2M1X2, "b"); |
| 222 | + |
| 223 | + CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); |
| 224 | + CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); |
| 225 | + |
| 226 | + CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); |
| 227 | + |
| 228 | + TORCH_CHECK(A.dim() == 2, "a must be a matrix"); |
| 229 | + TORCH_CHECK(B.dim() == 2, "b must be a matrix"); |
| 230 | + TORCH_CHECK(A.sizes()[1] == B.sizes()[1], |
| 231 | + "a and b shapes cannot be multiplied (", A.sizes()[0], "x", |
| 232 | + A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")"); |
| 233 | + |
| 234 | + auto const m = A.sizes()[0]; |
| 235 | + auto const n = B.sizes()[0]; |
| 236 | + auto const k = A.sizes()[1] * 2; |
| 237 | + |
| 238 | + constexpr int alignment = 32; |
| 239 | + TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment, |
| 240 | + ", but got a shape: (", A.sizes()[0], "x", A.sizes()[1], |
| 241 | + "), k: ", k, "."); |
| 242 | + TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment, |
| 243 | + ", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ")."); |
| 244 | + |
| 245 | + auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; |
| 246 | + int rounded_m = round_up(m, 128); |
| 247 | + int rounded_n = round_up(n, 128); |
| 248 | + // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an |
| 249 | + // integer. |
| 250 | + int rounded_k = round_up(k / 16, 4); |
| 251 | + |
| 252 | + TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); |
| 253 | + TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); |
| 254 | + TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1], |
| 255 | + "scale_a and scale_b shapes cannot be multiplied (", |
| 256 | + A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0], |
| 257 | + "x", B_sf.sizes()[1], ")"); |
| 258 | + TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, |
| 259 | + "scale_a must be padded and swizzled to a shape (", rounded_m, |
| 260 | + "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x", |
| 261 | + A_sf.sizes()[1], ")"); |
| 262 | + TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, |
| 263 | + "scale_b must be padded and swizzled to a shape (", rounded_n, |
| 264 | + "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x", |
| 265 | + B_sf.sizes()[1], ")"); |
| 266 | + |
| 267 | + auto out_dtype = D.dtype(); |
| 268 | + at::cuda::CUDAGuard device_guard{(char)A.get_device()}; |
| 269 | + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); |
| 270 | + |
| 271 | + if (out_dtype == at::ScalarType::Half) { |
| 272 | + runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); |
| 273 | + } else if (out_dtype == at::ScalarType::BFloat16) { |
| 274 | + runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); |
| 275 | + } else if (out_dtype == at::ScalarType::Float) { |
| 276 | + runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); |
| 277 | + } else { |
| 278 | + TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm"); |
| 279 | + } |
| 280 | +} |
0 commit comments