/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/util/packed_stride.hpp>

#include <stdexcept>
#include <string>
#include <vector>

// clang-format off
// The fixed ordering of the headers is required for CUTLASS 3.2+
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/conv/collective/collective_builder.hpp>
#include <cutlass/conv/convnd_problem_shape.hpp>
#include <cutlass/conv/convolution.h>
#include <cutlass/conv/device/conv_universal_adapter.hpp>
#include <cutlass/conv/dispatch_policy.hpp>
#include <cutlass/conv/kernel/conv_universal.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
// clang-format on

namespace fbgemm_gpu {

#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

// Cutlass FP8 convolution kernel for SM100 (Blackwell architecture)
template <
    int TB_M,
    int TB_N,
    int TB_K,
    int TBS_M,
    int TBS_N,
    int TBS_K>
at::Tensor f8f8bf16_conv_impl(
    at::Tensor activation, // FP8 - NDHWC layout
    at::Tensor filter, // FP8 - KTRSC layout
    at::Tensor scale,
    std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
    std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
    std::vector<int64_t> dilation) { // [dilation_d, dilation_h, dilation_w]

  // Extract dimensions from activation (NDHWC)
  TORCH_CHECK(activation.dim() == 5, "Activation must be a 5D tensor (NDHWC)");
  TORCH_CHECK(filter.dim() == 5, "Filter must be a 5D tensor (KTRSC)");

  int n = activation.size(0);
  int d = activation.size(1);
  int h = activation.size(2);
  int w = activation.size(3);
  int c = activation.size(4);

  // Extract dimensions from filter (KTRSC)
  int k = filter.size(0);
  int t = filter.size(1);
  int r = filter.size(2);
  int s = filter.size(3);

  TORCH_CHECK(
      filter.size(4) == c, "Filter channels must match activation channels");

  // Extract padding, stride, dilation
  TORCH_CHECK(padding.size() == 3, "Padding must have 3 elements");
  TORCH_CHECK(stride.size() == 3, "Stride must have 3 elements");
  TORCH_CHECK(dilation.size() == 3, "Dilation must have 3 elements");

  int pad_d = padding[0];
  int pad_h = padding[1];
  int pad_w = padding[2];

  int stride_d = stride[0];
  int stride_h = stride[1];
  int stride_w = stride[2];

  int dilation_d = dilation[0];
  int dilation_h = dilation[1];
  int dilation_w = dilation[2];

  // Calculate output dimensions
  int z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d;
  int p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h;
  int q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w;
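  // e.g. for d = 8, t = 3, pad_d = 1, dilation_d = 1, stride_d = 1:
  //   z = 1 + (8 + 2 - 3) / 1 = 8, so "same" padding preserves the extent.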

  TORCH_CHECK(
      activation.is_cuda() && activation.is_contiguous(),
      "Activation must be a contiguous CUDA tensor");
  TORCH_CHECK(
      filter.is_cuda() && filter.is_contiguous(),
      "Filter must be a contiguous CUDA tensor");
  // The epilogue reads the scale as a single fp32 scalar from device memory.
  TORCH_CHECK(
      scale.is_cuda() && scale.scalar_type() == at::kFloat &&
          scale.numel() == 1,
      "Scale must be a single fp32 value on the GPU");

  auto output =
      at::empty({n, z, p, q, k}, activation.options().dtype(at::kBFloat16));

  using ElementAct = cutlass::float_e4m3_t;
  using LayoutA = cutlass::layout::TensorNDHWC;
  constexpr int AlignmentAct = 128 / cutlass::sizeof_bits<ElementAct>::value;

  using ElementFlt = cutlass::float_e4m3_t;
  using LayoutB = cutlass::layout::TensorNDHWC;
  constexpr int AlignmentFlt = 128 / cutlass::sizeof_bits<ElementFlt>::value;

  using ElementOutput = cutlass::bfloat16_t;
  using LayoutC = cutlass::layout::TensorNDHWC;
  constexpr int AlignmentOutput =
      128 / cutlass::sizeof_bits<ElementOutput>::value;

  using ElementAccumulator = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  constexpr cutlass::conv::Operator ConvOp = cutlass::conv::Operator::kFprop;

  using TileShape = cute::
      Shape<cute::Int<TB_M>, cute::Int<TB_N>, cute::Shape<cute::Int<TB_K>>>;
  using ClusterShape =
      cute::Shape<cute::Int<TBS_M>, cute::Int<TBS_N>, cute::Int<TBS_K>>;
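  // TB_{M,N,K} give the CTA tile and TBS_{M,N,K} the cluster shape. The K
  // mode of the tile is wrapped in its own cute::Shape, which is the form
  // this conv collective builder accepts.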

  // Define Scale EVT.
  using Scale_ =
      cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAccumulator>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using EpilogueCompute = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementOutput,
      ElementAccumulator,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EpilogueEVT =
      cutlass::epilogue::fusion::Sm90EVT<EpilogueCompute, Scale_, Accum>;
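  // The visitor tree computes D = scale * accum: Scale_ broadcasts a single
  // device-side scalar, Accum fetches the accumulator, and EpilogueCompute
  // multiplies them with round-to-nearest into bf16. (The Sm90-named EVT
  // nodes are reused by the SM100 epilogue builder.)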

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          TileShape,
          ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator,
          ElementCompute,
          ElementOutput,
          LayoutC,
          AlignmentOutput,
          ElementOutput,
          LayoutC,
          AlignmentOutput,
          cutlass::epilogue::collective::EpilogueScheduleAuto,
          EpilogueEVT>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::conv::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          ConvOp,
          ElementAct,
          LayoutA,
          AlignmentAct,
          ElementFlt,
          LayoutB,
          AlignmentFlt,
          ElementAccumulator,
          TileShape,
          ClusterShape,
          cutlass::conv::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          cutlass::conv::collective::KernelScheduleAuto>::CollectiveOp;
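  // StageCountAutoCarveout sizes the mainloop software pipeline to whatever
  // shared memory remains after reserving the epilogue's SharedStorage.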

  using ProblemShape = cutlass::conv::ConvProblemShape<
      ConvOp,
      CollectiveMainloop::DispatchPolicy::NumSpatialDimensions>;
  using ConvKernel = cutlass::conv::kernel::
      ConvUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;

  using Conv = cutlass::conv::device::ConvUniversalAdapter<ConvKernel>;
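  // ConvUniversalAdapter is the host-side handle around the kernel; it
  // provides the can_implement / initialize / operator() entry points used
  // below.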

  using StrideC = typename Conv::ConvKernel::StrideC;
  using StrideD = typename Conv::ConvKernel::StrideD;

  ProblemShape problem_shape(
      cutlass::conv::Mode::kCrossCorrelation,
      {n, d, h, w, c}, // activation extents
      {k, t, r, s, c}, // filter extents
      {pad_d, pad_h, pad_w}, // lower padding
      {pad_d, pad_h, pad_w}, // upper padding (symmetric)
      {stride_d, stride_h, stride_w},
      {dilation_d, dilation_h, dilation_w},
      1 // group
  );

  StrideC stride_C;
  StrideD stride_D;

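  // ConvProblemShape only carries strides for the C tensor; D shares C's
  // NDHWC layout, so both CuTe stride tuples are populated from stride_C
  // below (mirroring the CUTLASS ConvUniversal examples).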
  cute::for_each(cute::make_seq<cute::rank<0>(StrideC{})>{}, [&](auto i) {
    cute::get<0, i>(stride_C) =
        problem_shape.stride_C[ProblemShape::RankT - 2 - i];
  });
  cute::for_each(cute::make_seq<cute::rank<0>(StrideD{})>{}, [&](auto i) {
    cute::get<0, i>(stride_D) =
        problem_shape.stride_C[ProblemShape::RankT - 2 - i];
  });

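  // Mainloop arguments are the A (activation) and B (filter) pointers; the
  // epilogue is given the same output buffer for C and D, which is safe here
  // because the EVT above never reads C.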
  typename Conv::Arguments arguments{
      problem_shape,
      {reinterpret_cast<ElementAct*>(activation.data_ptr()),
       reinterpret_cast<ElementFlt*>(filter.data_ptr())},
      {{},
       reinterpret_cast<ElementOutput*>(output.data_ptr<at::BFloat16>()),
       stride_C,
       reinterpret_cast<ElementOutput*>(output.data_ptr<at::BFloat16>()),
       stride_D}};

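  // Operands nest in the same order as Sm90EVT<EpilogueCompute, Scale_,
  // Accum>: Scale_ takes (static scalars, device scalar pointers), while the
  // accumulator fetch and the multiply take no arguments.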
  arguments.epilogue.thread = {
      {{}, {reinterpret_cast<ElementAccumulator*>(scale.data_ptr())}},
      {}, // Accumulator
      {}, // Multiplies
  };

  Conv conv;

  size_t workspace_size = Conv::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cutlass::Status status = conv.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot implement convolution: ") +
        cutlass::cutlassGetStatusString(status));
  }

  status = conv.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot initialize convolution: ") +
        cutlass::cutlassGetStatusString(status));
  }

  status = conv(at::cuda::getCurrentCUDAStream());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot run convolution: ") +
        cutlass::cutlassGetStatusString(status));
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return output;
}

at::Tensor f8f8bf16_conv(
    at::Tensor activation, // FP8 - NDHWC layout
    at::Tensor filter, // FP8 - KTRSC layout
    at::Tensor scale,
    std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
    std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
    std::vector<int64_t> dilation) { // [dilation_d, dilation_h, dilation_w]

  // Default configuration: 64x64x64 CTA tile with a 1x1x1 cluster.
  return f8f8bf16_conv_impl<64, 64, 64, 1, 1, 1>(
      activation, filter, scale, padding, stride, dilation);
}
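
// Example usage (a sketch; assumes the caller provides packed NDHWC
// Float8_e4m3fn tensors and a one-element fp32 scale on the same device):
//
//   auto x = at::randn({2, 8, 16, 16, 64}, at::kCUDA)
//                .to(at::kFloat8_e4m3fn);
//   auto w = at::randn({128, 3, 3, 3, 64}, at::kCUDA)
//                .to(at::kFloat8_e4m3fn);
//   auto s = at::ones({1}, at::dtype(at::kFloat).device(at::kCUDA));
//   auto y = fbgemm_gpu::f8f8bf16_conv(
//       x, w, s, /*padding=*/{1, 1, 1}, /*stride=*/{1, 1, 1},
//       /*dilation=*/{1, 1, 1}); // y: [2, 8, 16, 16, 128] in bf16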

#else

at::Tensor f8f8bf16_conv(
    at::Tensor activation,
    at::Tensor filter,
    at::Tensor scale,
    std::vector<int64_t> padding,
    std::vector<int64_t> stride,
    std::vector<int64_t> dilation) {
  throw std::runtime_error(
      "SM100 (Blackwell) architecture not supported. Requires CUTLASS 3.x with SM100 support.");
}

#endif

} // namespace fbgemm_gpu