
Commit fcd6ab4

jwfromm authored and meta-codesync[bot] committed

FP8 Convolution Kernel (#4994)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2009
Pull Request resolved: #4994

Initial implementation of the Blackwell FP8 convolution kernel in FBGEMM. It still requires tuning and testing, but basic functionality is now supported.

Reviewed By: zjing14

Differential Revision: D84378621

fbshipit-source-id: c80162bb887a94541ee691060f60aee34ab4b2ec

1 parent bd3243f commit fcd6ab4
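For context, a minimal usage sketch (not part of this commit; the shapes, scale value, and FP8 casts below are illustrative assumptions) showing how the new op might be called once FBGEMM is built with SM100 support:

    #include <ATen/ATen.h>

    // Hypothetical 3D conv: N=2, D/H/W=8/32/32, C=64, K=128, 3x3x3 filter.
    at::Tensor act = at::randn({2, 8, 32, 32, 64}, at::kCUDA)
                         .to(at::kFloat8_e4m3fn); // NDHWC activation
    at::Tensor flt = at::randn({128, 3, 3, 3, 64}, at::kCUDA)
                         .to(at::kFloat8_e4m3fn); // KTRSC filter
    // Scalar dequantization scale, read by the kernel as a float.
    at::Tensor scale =
        at::ones({1}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA));

    // BF16 output of shape {2, 8, 32, 32, 128} for this padding/stride/dilation.
    at::Tensor out = fbgemm_gpu::f8f8bf16_conv(
        act, flt, scale,
        /*padding=*/{1, 1, 1},
        /*stride=*/{1, 1, 1},
        /*dilation=*/{1, 1, 1});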

File tree

4 files changed: +376 -0 lines changed
Lines changed: 272 additions & 0 deletions

@@ -0,0 +1,272 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/util/packed_stride.hpp>

// clang-format off
// The fixed ordering of the headers is required for CUTLASS 3.2+
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/conv/collective/collective_builder.hpp>
#include <cutlass/conv/convnd_problem_shape.hpp>
#include <cutlass/conv/convolution.h>
#include <cutlass/conv/device/conv_universal_adapter.hpp>
#include <cutlass/conv/dispatch_policy.hpp>
#include <cutlass/conv/kernel/conv_universal.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
// clang-format on

namespace fbgemm_gpu {

#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

// Cutlass FP8 convolution kernel for SM100 (Blackwell architecture)
template <
    int TB_M,
    int TB_N,
    int TB_K,
    int TBS_M,
    int TBS_N,
    int TBS_K>
at::Tensor f8f8bf16_conv_impl(
    at::Tensor activation, // FP8 - NDHWC layout
    at::Tensor filter, // FP8 - KTRSC layout
    at::Tensor scale,
    std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
    std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
    std::vector<int64_t> dilation) { // [dilation_d, dilation_h, dilation_w]

  // Extract dimensions from activation (NDHWC)
  TORCH_CHECK(activation.dim() == 5, "Activation must be 5D tensor (NDHWC)");
  TORCH_CHECK(filter.dim() == 5, "Filter must be 5D tensor (KTRSC)");

  int n = activation.size(0);
  int d = activation.size(1);
  int h = activation.size(2);
  int w = activation.size(3);
  int c = activation.size(4);

  // Extract dimensions from filter (KTRSC)
  int k = filter.size(0);
  int t = filter.size(1);
  int r = filter.size(2);
  int s = filter.size(3);

  TORCH_CHECK(
      filter.size(4) == c, "Filter channels must match activation channels");

  // Extract padding, stride, dilation
  TORCH_CHECK(padding.size() == 3, "Padding must have 3 elements");
  TORCH_CHECK(stride.size() == 3, "Stride must have 3 elements");
  TORCH_CHECK(dilation.size() == 3, "Dilation must have 3 elements");

  int pad_d = padding[0];
  int pad_h = padding[1];
  int pad_w = padding[2];

  int stride_d = stride[0];
  int stride_h = stride[1];
  int stride_w = stride[2];

  int dilation_d = dilation[0];
  int dilation_h = dilation[1];
  int dilation_w = dilation[2];

  // Calculate output dimensions
  int z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d;
  int p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h;
  int q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w;
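
  // Example (illustrative values): with d = 16, t = 3, pad_d = 1,
  // dilation_d = 1, stride_d = 1, the output depth is
  // z = 1 + (16 + 2 - 3) / 1 = 16, i.e. "same"-style padding.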

  TORCH_CHECK(activation.is_cuda() && activation.is_contiguous());
  TORCH_CHECK(filter.is_cuda() && filter.is_contiguous());

  auto output =
      at::empty({n, z, p, q, k}, activation.options().dtype(at::kBFloat16));

  using ElementAct = cutlass::float_e4m3_t;
  using LayoutA = cutlass::layout::TensorNDHWC;
  constexpr int AlignmentAct = 128 / cutlass::sizeof_bits<ElementAct>::value;

  using ElementFlt = cutlass::float_e4m3_t;
  using LayoutB = cutlass::layout::TensorNDHWC;
  constexpr int AlignmentFlt = 128 / cutlass::sizeof_bits<ElementFlt>::value;

  using ElementOutput = cutlass::bfloat16_t;
  using LayoutC = cutlass::layout::TensorNDHWC;
  constexpr int AlignmentOutput =
      128 / cutlass::sizeof_bits<ElementOutput>::value;
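
  // Each alignment is expressed in elements per 128-bit vectorized access:
  // 16 for the FP8 operands, 8 for the BF16 output.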

  using ElementAccumulator = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  constexpr cutlass::conv::Operator ConvOp = cutlass::conv::Operator::kFprop;

  using TileShape = cute::
      Shape<cute::Int<TB_M>, cute::Int<TB_N>, cute::Shape<cute::Int<TB_K>>>;
  using ClusterShape =
      cute::Shape<cute::Int<TBS_M>, cute::Int<TBS_N>, cute::Int<TBS_K>>;

  // Define Scale EVT.
  using Scale_ =
      cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAccumulator>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using EpilogueCompute = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementOutput,
      ElementAccumulator,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EpilogueEVT =
      cutlass::epilogue::fusion::Sm90EVT<EpilogueCompute, Scale_, Accum>;
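
  // The EVT evaluates output = scale * accumulator, broadcasting a single
  // scalar scale across the whole output tensor.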

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          TileShape,
          ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator,
          ElementCompute,
          ElementOutput,
          LayoutC,
          AlignmentOutput,
          ElementOutput,
          LayoutC,
          AlignmentOutput,
          cutlass::epilogue::collective::EpilogueScheduleAuto,
          EpilogueEVT>::CollectiveOp;
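
  // The stage count is chosen automatically, carving the epilogue's
  // shared-memory storage out of the budget available for mainloop stages.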
  using CollectiveMainloop =
      typename cutlass::conv::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          ConvOp,
          ElementAct,
          LayoutA,
          AlignmentAct,
          ElementFlt,
          LayoutB,
          AlignmentFlt,
          ElementAccumulator,
          TileShape,
          ClusterShape,
          cutlass::conv::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          cutlass::conv::collective::KernelScheduleAuto>::CollectiveOp;

  using ProblemShape = cutlass::conv::ConvProblemShape<
      ConvOp,
      CollectiveMainloop::DispatchPolicy::NumSpatialDimensions>;
  using ConvKernel = cutlass::conv::kernel::
      ConvUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;

  using Conv = cutlass::conv::device::ConvUniversalAdapter<ConvKernel>;

  using StrideC = typename Conv::ConvKernel::StrideC;
  using StrideD = typename Conv::ConvKernel::StrideD;

  ProblemShape problem_shape(
      cutlass::conv::Mode::kCrossCorrelation,
      {n, d, h, w, c},
      {k, t, r, s, c},
      {pad_d, pad_h, pad_w},
      {pad_d, pad_h, pad_w},
      {stride_d, stride_h, stride_w},
      {dilation_d, dilation_h, dilation_w},
      1 // group
  );

  StrideC stride_C;
  StrideD stride_D;
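
  // Fill the output stride tuples from the problem shape. Entries are copied
  // in reverse order; the unit channel stride of the packed NDHWC layout is
  // implicit, and C and D share strides since both alias the output tensor.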
  cute::for_each(cute::make_seq<cute::rank<0>(StrideC{})>{}, [&](auto i) {
    cute::get<0, i>(stride_C) =
        problem_shape.stride_C[ProblemShape::RankT - 2 - i];
  });
  cute::for_each(cute::make_seq<cute::rank<0>(StrideD{})>{}, [&](auto i) {
    cute::get<0, i>(stride_D) =
        problem_shape.stride_C[ProblemShape::RankT - 2 - i];
  });
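
  // Arguments: problem shape, mainloop inputs {A, B}, epilogue {C, D} with
  // strides. C and D point at the same output buffer; no residual is read.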
  typename Conv::Arguments arguments{
      problem_shape,
      {reinterpret_cast<ElementAct*>(activation.data_ptr()),
       reinterpret_cast<ElementFlt*>(filter.data_ptr())},
      {{},
       reinterpret_cast<ElementOutput*>(output.data_ptr<at::BFloat16>()),
       stride_C,
       reinterpret_cast<ElementOutput*>(output.data_ptr<at::BFloat16>()),
       stride_D}};

  arguments.epilogue.thread = {
      {{}, {reinterpret_cast<ElementAccumulator*>(scale.data_ptr())}},
      {}, // Accumulator
      {}, // Multiplies
  };

  Conv conv;

  size_t workspace_size = Conv::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cutlass::Status status = conv.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement convolution");
  }

  status = conv.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize convolution");
  }

  status = conv(at::cuda::getCurrentCUDAStream());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot run convolution: ") +
        cutlass::cutlassGetStatusString(status));
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return output;
}

at::Tensor f8f8bf16_conv(
    at::Tensor activation, // FP8 - NDHWC layout
    at::Tensor filter, // FP8 - KTRSC layout
    at::Tensor scale,
    std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
    std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
    std::vector<int64_t> dilation) { // [dilation_d, dilation_h, dilation_w]

  return f8f8bf16_conv_impl<64, 64, 64, 1, 1, 1>(
      activation, filter, scale, padding, stride, dilation);
}

#else

at::Tensor f8f8bf16_conv(
    at::Tensor activation,
    at::Tensor filter,
    at::Tensor scale,
    std::vector<int64_t> padding,
    std::vector<int64_t> stride,
    std::vector<int64_t> dilation) {
  throw std::runtime_error(
      "SM100 (Blackwell) architecture not supported. Requires CUTLASS 3.x with SM100 support.");
}

#endif

} // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 47 additions & 0 deletions

@@ -68,6 +68,13 @@ at::Tensor f8f8bf16_tensorwise(
     double scale,
     bool use_fast_accum = true);
 at::Tensor f8f8bf16_lite(at::Tensor XQ, at::Tensor WQ, at::Tensor scale);
+at::Tensor f8f8bf16_conv(
+    at::Tensor activation,
+    at::Tensor filter,
+    at::Tensor scale,
+    std::vector<int64_t> padding,
+    std::vector<int64_t> stride,
+    std::vector<int64_t> dilation);
 std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList X,
     at::TensorList W);

@@ -317,6 +324,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("bf16fp8bf16_fast_gemv", bf16fp8bf16_fast_gemv);
   m.impl("fp8fp8bf16_fast_gemv", fp8fp8bf16_fast_gemv);
   m.impl("f8f8bf16_lite", f8f8bf16_lite);
+  m.impl("f8f8bf16_conv", f8f8bf16_conv);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
   m.impl("f8i4bf16_shuffled", f8i4bf16_shuffled);
   m.impl("bf16i4bf16_shuffled", bf16i4bf16_shuffled);

@@ -614,6 +622,44 @@ at::Tensor f8f8bf16_lite_meta(at::Tensor X, at::Tensor W, at::Tensor scale) {
   return Y;
 }

+at::Tensor f8f8bf16_conv_meta(
+    at::Tensor activation,
+    at::Tensor filter,
+    at::Tensor /* scale */,
+    std::vector<int64_t> padding,
+    std::vector<int64_t> stride,
+    std::vector<int64_t> dilation) {
+  TORCH_CHECK(activation.dim() == 5, "Activation must be 5D tensor (NDHWC)");
+  TORCH_CHECK(filter.dim() == 5, "Filter must be 5D tensor (KTRSC)");
+
+  const at::SymInt n = activation.sym_size(0);
+  const at::SymInt d = activation.sym_size(1);
+  const at::SymInt h = activation.sym_size(2);
+  const at::SymInt w = activation.sym_size(3);
+  const at::SymInt k = filter.sym_size(0);
+  const at::SymInt t = filter.sym_size(1);
+  const at::SymInt r = filter.sym_size(2);
+  const at::SymInt s = filter.sym_size(3);
+
+  int pad_d = padding[0];
+  int pad_h = padding[1];
+  int pad_w = padding[2];
+  int stride_d = stride[0];
+  int stride_h = stride[1];
+  int stride_w = stride[2];
+  int dilation_d = dilation[0];
+  int dilation_h = dilation[1];
+  int dilation_w = dilation[2];
+
+  at::SymInt z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d;
+  at::SymInt p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h;
+  at::SymInt q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w;
+
+  auto Y = at::empty_symint(
+      {n, z, p, q, k}, activation.options().dtype(at::kBFloat16));
+  return Y;
+}
+
 at::Tensor f8i4bf16_rowwise_meta(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // INT4

@@ -843,6 +889,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("bf16i4bf16_shuffled_batched", bf16i4bf16_shuffled_batched_meta);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
   m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
+  m.impl("f8f8bf16_conv", f8f8bf16_conv_meta);
   m.impl("scaled_fp4_quant", scaled_fp4_quant_meta);
   m.impl("preshuffle_i4", preshuffle_i4_meta);
   m.impl("f8i4bf16_shuffled", f8i4bf16_shuffled_meta);

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp

Lines changed: 2 additions & 0 deletions

@@ -50,6 +50,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "fp8fp8bf16_fast_gemv(Tensor X, Tensor W, Tensor x_scale, Tensor w_scale, bool is_batched=False) -> Tensor");
   m.def("f8f8bf16_lite(Tensor XQ, Tensor WQ, Tensor scale) -> Tensor");
+  m.def(
+      "f8f8bf16_conv(Tensor activation, Tensor filter, Tensor scale, int[] padding, int[] stride, int[] dilation) -> Tensor");
   m.def(
       "bf16i4bf16_rowwise(Tensor X, Tensor W, Tensor w_scale_group, Tensor w_zero_group) -> Tensor");
   m.def(
