Refactor FP8 grouped GEMM to prepare cudagraph support (#3369)
Summary:

X-link: facebookresearch/FBGEMM#460

Refactor the FP8 grouped GEMM to extract the grouped GEMM arguments and configurations ahead of the grouped GEMM kernel, so that they can be reused by another CUDA kernel that sets up the arguments on device.

Differential Revision: D65548954
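
To make the intent concrete, here is a minimal, self-contained C++ sketch of the pattern (placeholder types only, not the actual CUTLASS aliases from this diff): the shared type configuration moves out of the launcher into a templated struct inside a namespace, so that a separate CUDA kernel that builds the arguments on device can name exactly the same types.

#include <cstdint>

namespace GroupedGemmArgs {

// Stand-ins for the CUTLASS element types used in the diff
// (cutlass::float_e4m3_t / cutlass::bfloat16_t); plain integer types keep
// this sketch compilable without CUTLASS.
using ElementInputA = std::uint8_t;
using ElementOutput = std::uint16_t;

// Configuration bundle keyed on the tile/cluster/schedule parameters.
// In the commit this also defines TileShape, ClusterShape, the kernel and
// epilogue schedules, the collective mainloop/epilogue, and the Gemm adapter.
template <int TB_M, int TB_N, int TB_K, bool PONG>
struct GroupedGemmConfigs {
  using StrideInputA = std::int64_t;  // placeholder for Gemm::GemmKernel::InternalStrideA
  using StrideInputB = std::int64_t;
  using StrideOutput = std::int64_t;
};

}  // namespace GroupedGemmArgs

// The launcher no longer defines every alias inline; it instantiates the
// shared bundle, and a device-side argument-setup kernel can later do the
// same, so both agree on the stride/shape types by construction.
template <int TB_M, int TB_N, int TB_K, bool PONG, bool FAST_ACCUM>
void f8f8bf16_grouped_sketch() {
  using Configs =
      GroupedGemmArgs::GroupedGemmConfigs<TB_M, TB_N, TB_K, PONG>;
  typename Configs::StrideInputA stride_a{};  // same alias, usable anywhere
  (void)stride_a;
}

In the real diff, GroupedGemmConfigs additionally bundles the CUTLASS TileShape, ClusterShape, kernel/epilogue schedules, collective mainloop and epilogue, the GemmKernel, and the Gemm adapter, and f8f8bf16_grouped_impl pulls its stride types from that struct.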
jiawenliu64 authored and facebook-github-bot committed Nov 14, 2024
1 parent 6dd2d31 commit 9b4b04b
Showing 1 changed file with 134 additions and 100 deletions.
@@ -25,64 +25,33 @@ namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

namespace GroupedGemmArgs {
using ProblemShape =
cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
using ElementInputA = cutlass::float_e4m3_t;
using ElementInputB = cutlass::float_e4m3_t;
using ElementOutput = cutlass::bfloat16_t;
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = float;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
// Template structure to encapsulate configurations
template <
int TB_M,
int TB_N,
int TB_K,
int TBS_M,
int TBS_N,
int TBS_K,
bool PONG,
bool FAST_ACCUM>
std::vector<at::Tensor> f8f8bf16_grouped_impl(
const std::vector<at::Tensor>& XQ, // FP8
const std::vector<at::Tensor>& WQ, // FP8
const std::vector<at::Tensor>& scale) {
int problem_count = XQ.size();
TORCH_CHECK(WQ.size() == problem_count);
if (problem_count == 0) {
return std::vector<at::Tensor>();
}

using ProblemShape =
cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>; // <M,N,K>
// per group
using ElementInputA =
cutlass::float_e4m3_t; // Element type for A matrix operand
using ElementInputB =
cutlass::float_e4m3_t; // Element type for B matrix operand
using ElementOutput =
cutlass::bfloat16_t; // Element type for C and D matrix operands

using LayoutInputA =
cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementInputA>::value; // Alignment of A matrix
// in units of elements
// (up to 16 bytes)

using LayoutInputB =
cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementInputB>::value; // Alignment of B matrix
// in units of elements
// (up to 16 bytes)

using LayoutOutput =
cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentD =
128 / cutlass::sizeof_bits<ElementOutput>::value; // Alignment of C matrix
// in units of elements
// (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
// supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based
// on the tile size

bool PONG>
struct GroupedGemmConfigs {
using TileShape =
cute::Shape<cute::Int<TB_M>, cute::Int<TB_N>, cute::Int<TB_K>>;
using ClusterShape =
cute::Shape<cute::Int<TBS_M>, cute::Int<TBS_N>, cute::Int<TBS_K>>;
using CooperativeSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
using PongSchedule =
@@ -91,16 +60,10 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using PongEpilogueSchedule =
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;

using KernelSchedule =
cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
using EpilogueSchedule = cute::
conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
using TileShape =
cute::Shape<cute::Int<TB_M>, cute::Int<TB_N>, cute::Int<TB_K>>;
using ClusterShape =
cute::Shape<cute::Int<TBS_M>, cute::Int<TBS_N>, cute::Int<TBS_K>>;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
@@ -112,25 +75,24 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
ElementAccumulator,
ElementOutput,
LayoutOutput*,
AlignmentD,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementOutput,
LayoutOutput*,
AlignmentD,
128 / cutlass::sizeof_bits<ElementOutput>::value,
EpilogueSchedule,
cutlass::epilogue::fusion::LinearCombination<
ElementOutput,
ElementAccumulator>>::CollectiveOp;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementInputA,
LayoutInputA*,
AlignmentA,
128 / cutlass::sizeof_bits<ElementInputA>::value,
ElementInputB,
LayoutInputB*,
AlignmentB,
128 / cutlass::sizeof_bits<ElementInputB>::value,
ElementAccumulator,
TileShape,
ClusterShape,
@@ -139,12 +101,54 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::
GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideInputA = typename Gemm::GemmKernel::InternalStrideA;
using StrideInputB = typename Gemm::GemmKernel::InternalStrideB;
using StrideOutput = typename Gemm::GemmKernel::InternalStrideD;
};
} // namespace GroupedGemmArgs

template <
int TB_M,
int TB_N,
int TB_K,
int TBS_M,
int TBS_N,
int TBS_K,
bool PONG,
bool FAST_ACCUM>
std::vector<at::Tensor> f8f8bf16_grouped_impl(
const std::vector<at::Tensor>& XQ, // FP8
const std::vector<at::Tensor>& WQ, // FP8
const std::vector<at::Tensor>& scale) {
int problem_count = XQ.size();
TORCH_CHECK(WQ.size() == problem_count);
if (problem_count == 0) {
return std::vector<at::Tensor>();
}
using GroupedGemmConfigs = GroupedGemmArgs::
GroupedGemmConfigs<TB_M, TB_N, TB_K, TBS_M, TBS_N, TBS_K, PONG>;

constexpr int AlignmentA =
128 /
cutlass::sizeof_bits<
GroupedGemmArgs::ElementInputA>::value; // Alignment of A matrix
// in units of elements
// (up to 16 bytes)

constexpr int AlignmentB =
128 /
cutlass::sizeof_bits<
GroupedGemmArgs::ElementInputB>::value; // Alignment of B matrix
// in units of elements
// (up to 16 bytes)

constexpr int AlignmentD =
128 /
cutlass::sizeof_bits<
GroupedGemmArgs::ElementOutput>::value; // Alignment of C matrix
// in units of elements
// (up to 16 bytes)

int64_t output_offset = 0;
int64_t total_output_size = 0;
@@ -155,9 +159,10 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
at::TensorOptions().dtype(at::kLong).pinned_memory(true));
int64_t* output_ptr = output_args.data_ptr<int64_t>();

const int64_t problem_shape_size =
problem_count * ((int64_t)sizeof(ProblemShape::UnderlyingProblemShape));
const int64_t stride_size = problem_count * ((int64_t)sizeof(StrideInputA));
const int64_t problem_shape_size = problem_count *
((int64_t)sizeof(GroupedGemmArgs::ProblemShape::UnderlyingProblemShape));
const int64_t stride_size = problem_count *
((int64_t)sizeof(typename GroupedGemmConfigs::StrideInputA));

at::Tensor input_args = at::empty(
{problem_count * 3 + problem_shape_size + stride_size * 3},
@@ -174,15 +179,17 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
input_args.data_ptr<int64_t>() + (problem_count * 3 * sizeof(int64_t)) +
problem_shape_size);

ProblemShape::UnderlyingProblemShape* problem_shape_ptr =
reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape* problem_shape_ptr =
reinterpret_cast<GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
StrideInputA* stride_input_A_ptr =
reinterpret_cast<StrideInputA*>(stride_buf);
StrideInputB* stride_input_B_ptr =
reinterpret_cast<StrideInputB*>(stride_buf + stride_size);
StrideOutput* stride_output_ptr =
reinterpret_cast<StrideOutput*>(stride_buf + (stride_size * 2));
typename GroupedGemmConfigs::StrideInputA* stride_input_A_ptr =
reinterpret_cast<typename GroupedGemmConfigs::StrideInputA*>(stride_buf);
typename GroupedGemmConfigs::StrideInputB* stride_input_B_ptr =
reinterpret_cast<typename GroupedGemmConfigs::StrideInputB*>(
stride_buf + stride_size);
typename GroupedGemmConfigs::StrideOutput* stride_output_ptr =
reinterpret_cast<typename GroupedGemmConfigs::StrideOutput*>(
stride_buf + (stride_size * 2));

for (int i = 0; i < problem_count; ++i) {
const int64_t output_size = XQ[i].size(0) * WQ[i].size(0);
@@ -206,20 +213,43 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(

xq_ptr[i] = reinterpret_cast<int64_t>(XQ[i].data_ptr<at::Float8_e4m3fn>());
wq_ptr[i] = reinterpret_cast<int64_t>(WQ[i].data_ptr<at::Float8_e4m3fn>());
scale_ptr[i] =
reinterpret_cast<int64_t>(scale[i].data_ptr<ElementAccumulator>());
problem_shape_ptr[i] = ProblemShape::UnderlyingProblemShape(m, n, k);
stride_input_A_ptr[i] =
cutlass::make_cute_packed_stride(StrideInputA{}, {m, k, 1});
stride_input_B_ptr[i] =
cutlass::make_cute_packed_stride(StrideInputB{}, {n, k, 1});
stride_output_ptr[i] =
cutlass::make_cute_packed_stride(StrideOutput{}, {m, n, 1});
scale_ptr[i] = reinterpret_cast<int64_t>(
scale[i].data_ptr<GroupedGemmArgs::ElementAccumulator>());
problem_shape_ptr[i] =
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape(m, n, k);
stride_input_A_ptr[i] = cutlass::make_cute_packed_stride(
typename GroupedGemmConfigs::StrideInputA{}, {m, k, 1});
stride_input_B_ptr[i] = cutlass::make_cute_packed_stride(
typename GroupedGemmConfigs::StrideInputB{}, {n, k, 1});
stride_output_ptr[i] = cutlass::make_cute_packed_stride(
typename GroupedGemmConfigs::StrideOutput{}, {m, n, 1});
}

const auto device = XQ[0].device();
input_args = input_args.to(device, /*non_blocking=*/true);
output_args = output_args.to(device, /*non_blocking=*/true);
// Allocate input args memory on the GPU
size_t input_args_size = input_args.numel() * sizeof(int64_t);
at::Tensor d_input_args = at::empty(
{problem_count * 3 + problem_shape_size + stride_size * 3},
at::TensorOptions().dtype(at::kLong).device(at::kCUDA));

// Allocate output args memory on the GPU
size_t output_args_size = output_args.numel() * sizeof(int64_t);
at::Tensor d_output_args = at::empty(
{problem_count}, at::TensorOptions().dtype(at::kLong).device(at::kCUDA));

// Copy data from CPU to GPU asynchronously
cudaMemcpyAsync(
d_input_args.data_ptr(),
input_args.data_ptr<int64_t>(),
input_args_size,
cudaMemcpyHostToDevice,
at::cuda::getCurrentCUDAStream());

cudaMemcpyAsync(
d_output_args.data_ptr(),
output_args.data_ptr<int64_t>(),
output_args_size,
cudaMemcpyHostToDevice,
at::cuda::getCurrentCUDAStream());

output_ptr = output_args.data_ptr<int64_t>();
xq_ptr = input_args.data_ptr<int64_t>();
@@ -230,41 +260,45 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
problem_shape_buf = reinterpret_cast<uint8_t*>(
input_args.data_ptr<int64_t>() + (problem_count * 3 * sizeof(int64_t)));
problem_shape_ptr =
reinterpret_cast<typename ProblemShape::UnderlyingProblemShape*>(
reinterpret_cast<GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
stride_buf = reinterpret_cast<uint8_t*>(
input_args.data_ptr<int64_t>() + (problem_count * 3 * sizeof(int64_t)) +
problem_shape_size);
stride_input_A_ptr = reinterpret_cast<StrideInputA*>(stride_buf);
stride_input_A_ptr =
reinterpret_cast<typename GroupedGemmConfigs::StrideInputA*>(stride_buf);
stride_input_B_ptr =
reinterpret_cast<StrideInputB*>(stride_buf + stride_size);
reinterpret_cast<typename GroupedGemmConfigs::StrideInputB*>(
stride_buf + stride_size);
stride_output_ptr =
reinterpret_cast<StrideOutput*>(stride_buf + (stride_size * 2));
reinterpret_cast<typename GroupedGemmConfigs::StrideOutput*>(
stride_buf + (stride_size * 2));

typename Gemm::Arguments arguments;
typename GroupedGemmConfigs::Gemm::Arguments arguments;
decltype(arguments.epilogue.thread) fusion_args;
fusion_args.alpha_ptr_array =
reinterpret_cast<const ElementAccumulator**>(scale_ptr);
reinterpret_cast<const GroupedGemmArgs::ElementAccumulator**>(scale_ptr);
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};

arguments = typename Gemm::Arguments{
arguments = typename GroupedGemmConfigs::Gemm::Arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{problem_count, problem_shape_ptr, nullptr},
{reinterpret_cast<const ElementInputA**>(xq_ptr),
{reinterpret_cast<const GroupedGemmArgs::ElementInputA**>(xq_ptr),
stride_input_A_ptr,
reinterpret_cast<const ElementInputB**>(wq_ptr),
reinterpret_cast<const GroupedGemmArgs::ElementInputB**>(wq_ptr),
stride_input_B_ptr},
{fusion_args,
reinterpret_cast<const ElementOutput**>(output_ptr),
reinterpret_cast<const GroupedGemmArgs::ElementOutput**>(output_ptr),
stride_output_ptr,
reinterpret_cast<ElementOutput**>(output_ptr),
reinterpret_cast<GroupedGemmArgs::ElementOutput**>(output_ptr),
stride_output_ptr}};

Gemm gemm;
typename GroupedGemmConfigs::Gemm gemm;

// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
size_t workspace_size =
GroupedGemmConfigs::Gemm::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
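
For context on the "prepare cudagraph support" part of the title, the sketch below is not part of this commit; all names are illustrative and it assumes CUDA 12, matching the CUDART_VERSION >= 12000 guard in this file. It shows why argument setup is being pushed toward the device: once a grouped GEMM launch is captured in a CUDA graph, the pointers baked into the graph are fixed, so later calls can only refresh the contents of persistent device-side argument buffers, which a small setup kernel running before the GEMM inside the captured region can do.

#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical device-side argument setup: in a cudagraph-ready version, the
// per-group pointers, problem shapes, and strides would be (re)written here
// into a persistent device buffer instead of being rebuilt on the host.
__global__ void setup_args_kernel(int64_t* d_args) {
  // illustrative only; intentionally left empty
}

void capture_and_replay(cudaStream_t stream, int64_t* d_args) {
  // Error checking omitted for brevity.
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;

  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  // 1) Rebuild the kernel arguments entirely on device ...
  setup_args_kernel<<<1, 1, 0, stream>>>(d_args);
  // 2) ... then launch the grouped GEMM that consumes those device buffers
  //    (the CUTLASS GEMM launch from this file would be issued here).
  cudaStreamEndCapture(stream, &graph);

  cudaGraphInstantiate(&graph_exec, graph, 0);  // CUDA 12 signature

  // Replays are cheap; only the contents of d_args change between replays,
  // never the pointers that were captured into the graph.
  cudaGraphLaunch(graph_exec, stream);
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
}

Splitting argument construction out of the inline body of f8f8bf16_grouped_impl, as this commit does, is the prerequisite for moving that construction into such a device-side setup kernel.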
