[inference] Support wint4 groupwise with cutlass gemm #60422

Merged
93 changes: 81 additions & 12 deletions paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h
@@ -1,18 +1,34 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

@@ -42,5 +58,58 @@ namespace arch {
// Tag which triggers an MMA that dequantizes interleaved B to the type of A.
struct OpMultiplyAddDequantizeInterleavedBToA;

/*
Below we have extra tags to signal what kind of dequantization we want to do
(per col, scale only fine grained, finegrained with zero). This still lets us
use the existing template infrastructure (incl. that in CUTLASS). However, we
split out the template below into OpMultiplyAddDequantizeInterleavedBToA along
with the quantization op before instantiating the GEMM pieces.

Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount
of code we need to duplicate.
*/
struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale;

// The default just forwards the original operator
template <typename MmaOp, bool FineGrained>
struct TagOperator {
using TaggedOperator = MmaOp;
};

// Specializations below attach more information to the operator
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, false> {
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
};

template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, true> {
using TaggedOperator =
OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale;
};

// Here we instantiate some structs to "detag" the tagged operator, splitting it
// back into the original operator plus the extra information. If no extra info
// was tagged, per-column scaling is assumed as the default.
template <typename TaggedMmaOp>
struct DetagOperator {
using Operator = TaggedMmaOp;
static constexpr bool FineGrained = false;
};

template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale> {
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
static constexpr bool FineGrained = false;
};

template <>
struct DetagOperator<
OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale> {
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
static constexpr bool FineGrained = true;
};

} // namespace arch
} // namespace cutlass
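
A quick, illustrative round trip of the tag/detag helpers above (not part of the PR; it only assumes this header is reachable on the include path):

```cpp
// Illustrative only: tagging an operator as fine-grained and then detagging it
// recovers the original operator plus the FineGrained flag.
#include <type_traits>

#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"

using cutlass::arch::DetagOperator;
using cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
using cutlass::arch::TagOperator;

// Tagging with FineGrained = true yields the fine-grained scale tag.
using Tagged = TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
                           true>::TaggedOperator;
static_assert(
    std::is_same<
        Tagged,
        cutlass::arch::
            OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale>::value,
    "expected the fine-grained tag");

// Detagging splits the tag back into the original operator + the flag.
using Detagged = DetagOperator<Tagged>;
static_assert(std::is_same<Detagged::Operator,
                           OpMultiplyAddDequantizeInterleavedBToA>::value,
              "detag should recover the original operator");
static_assert(Detagged::FineGrained, "the fine-grained flag should survive");
```

The pass-through primary template is what lets untagged operators keep working unchanged, which is how the comment above avoids duplicating the surrounding template infrastructure.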
@@ -61,7 +61,9 @@ enum class CutlassTileConfig {

// configs for large M in encoder
CtaShape128x256x64_WarpShape64x64x64,
// CtaShape256x128x64_WarpShape64x64x64

// configs for finegrained
CtaShape256x128x64_WarpShape64x64x64,
};

enum class SplitKStyle {
@@ -56,8 +56,10 @@ template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename KernelArch, ///! The Architecture this kernel is compiled
/// for. Used since SIMT kernels lose top-level
/// arch.
bool SplitKSerial ///! If true, code supporting split-K via serial
bool SplitKSerial, ///! If true, code supporting split-K via serial
/// reduction is enabled.
bool Finegrained ///! If true, finegrained mode is enabled.
/// Currently only groupwise is supported.
>
struct GemmFpAIntB {
using Mma = Mma_;
@@ -103,6 +105,7 @@ struct GemmFpAIntB {
/// Parameters structure
struct Arguments : UniversalArgumentsBase {
cutlass::gemm::GemmCoord problem_size;
int group_size;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::TensorRef ref_B;
typename Mma::IteratorScale::TensorRef ref_scale;
@@ -125,6 +128,7 @@ struct GemmFpAIntB {

CUTLASS_HOST_DEVICE
Arguments(cutlass::gemm::GemmCoord const& problem_size,
int group_size,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Mma::IteratorScale::TensorRef ref_scale,
@@ -143,6 +147,7 @@ struct GemmFpAIntB {
problem_size,
/*serial_split_k_factor=*/serial_split_k_factor,
/*batch_stride_D=*/0),
group_size(group_size),
ref_A(ref_A),
ref_B(ref_B),
ref_scale(ref_scale),
@@ -181,6 +186,7 @@ struct GemmFpAIntB {
int const* gather_A_indices;
int const* gather_B_indices;
int const* scatter_D_indices;
int group_size;

//
// Methods
@@ -192,6 +198,7 @@ struct GemmFpAIntB {
CUTLASS_HOST_DEVICE
Params(Arguments const& args, int device_sms, int sm_occupancy)
: ParamsBase(args, device_sms, sm_occupancy),
group_size(args.group_size),
params_A(args.ref_A.layout()),
ref_A(args.ref_A),
params_B(args.ref_B.layout()),
@@ -276,6 +283,52 @@ struct GemmFpAIntB {
return Status::kSuccess;
}

// Initializes the fine grained scale+bias iterator. Needed since the fine
// grained iterator has a different constructor signature than a regular
// cutlass iterator

template <typename IteratorScale, bool FineGrained>
struct initialize_scale {
CUTLASS_DEVICE static IteratorScale apply(
typename IteratorScale::Params const& params,
typename IteratorScale::Pointer pointer_scale,
typename IteratorScale::TensorCoord extent,
int thread_id,
typename IteratorScale::TensorCoord const& threadblock_offset,
int group_size);
};

template <typename IteratorScale>
struct initialize_scale<IteratorScale, true> {
CUTLASS_DEVICE static IteratorScale apply(
typename IteratorScale::Params const& params,
typename IteratorScale::Pointer pointer_scale,
typename IteratorScale::TensorCoord extent,
int thread_id,
typename IteratorScale::TensorCoord const& threadblock_offset,
int group_size) {
return IteratorScale(params,
pointer_scale,
extent,
thread_id,
threadblock_offset,
group_size);
}
};

template <typename IteratorScale>
struct initialize_scale<IteratorScale, false> {
CUTLASS_DEVICE static IteratorScale apply(
typename IteratorScale::Params const& params,
typename IteratorScale::Pointer pointer_scale,
typename IteratorScale::TensorCoord extent,
int thread_id,
typename IteratorScale::TensorCoord const& threadblock_offset,
int group_size) {
return IteratorScale(
params, pointer_scale, extent, thread_id, threadblock_offset);
}
};
static size_t get_extra_workspace_size(
Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) {
return 0;
@@ -335,8 +388,12 @@ struct GemmFpAIntB {
threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};

typename MatrixCoord::Index fg_row_offset =
threadblock_tile_offset.k() * params.gemm_k_size / 64;
typename MatrixCoord::Index scale_row_offset =
Finegrained == true ? fg_row_offset : 0;
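// With fine-grained (groupwise) quantization, every group of 64 rows of B
// along K shares one row of scales, so a threadblock's starting scale row is
// its starting K offset divided by the group size (hard-coded to 64 here);
// per-column mode keeps the single scale row at offset 0.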
cutlass::MatrixCoord tb_offset_scale{
0, threadblock_tile_offset.n() * Mma::Shape::kN};
scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};

// Problem size is a function of threadblock index in the K dimension
int problem_size_k =
@@ -368,11 +425,16 @@ struct GemmFpAIntB {
tb_offset_B,
params.gather_B_indices);

typename Mma::IteratorScale iterator_scale(params.params_scale,
params.ref_scale.data(),
{1, params.problem_size.n()},
thread_idx,
tb_offset_scale);
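// Fine-grained mode reads problem_size_k / 64 scale rows (one per group of
// 64), while per-column mode reads a single row; initialize_scale then invokes
// whichever Mma::IteratorScale constructor matches the chosen mode.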
typename MatrixCoord::Index scale_row_extent =
Finegrained == true ? problem_size_k / 64 : 1;
typename Mma::IteratorScale iterator_scale =
initialize_scale<typename Mma::IteratorScale, Finegrained>::apply(
params.params_scale,
params.ref_scale.data(),
{scale_row_extent, params.problem_size.n()},
thread_idx,
tb_offset_scale,
params.group_size);

// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
@@ -383,7 +445,11 @@ struct GemmFpAIntB {
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
Mma mma(shared_storage.main_loop,
params.group_size,
thread_idx,
warp_idx,
lane_idx);

typename Mma::FragmentC accumulators;

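The initialize_scale helper above exists only because the fine-grained scale iterator takes extra arguments (a multi-row extent and group_size) that the per-column iterator does not. Below is a standalone sketch of the same compile-time dispatch, using hypothetical iterator types that are not taken from the PR:

```cpp
#include <cstdio>

// Hypothetical per-column scale iterator: one row of N scales.
struct PerColScaleIter {
  PerColScaleIter(const float* p, int n) : ptr(p), cols(n), rows(1) {}
  const float* ptr;
  int cols;
  int rows;
};

// Hypothetical groupwise scale iterator: one row of scales per group along K.
struct GroupwiseScaleIter {
  GroupwiseScaleIter(const float* p, int rows, int n, int group_size)
      : ptr(p), cols(n), rows(rows), group(group_size) {}
  const float* ptr;
  int cols;
  int rows;
  int group;
};

// A bool template parameter selects which constructor signature is called,
// so the code that builds the iterator can stay uniform for both modes.
template <typename Iter, bool FineGrained>
struct MakeScaleIter;

template <typename Iter>
struct MakeScaleIter<Iter, false> {
  static Iter apply(const float* p, int rows, int n, int group_size) {
    (void)rows;
    (void)group_size;  // ignored in per-column mode
    return Iter(p, n);
  }
};

template <typename Iter>
struct MakeScaleIter<Iter, true> {
  static Iter apply(const float* p, int rows, int n, int group_size) {
    return Iter(p, rows, n, group_size);  // forward the groupwise extras
  }
};

int main() {
  static float scales[64 * 128] = {};  // e.g. K = 4096, group_size = 64, N = 128
  auto per_col = MakeScaleIter<PerColScaleIter, false>::apply(scales, 1, 128, -1);
  auto grouped = MakeScaleIter<GroupwiseScaleIter, true>::apply(scales, 64, 128, 64);
  std::printf("per-col rows: %d, groupwise rows: %d\n", per_col.rows, grouped.rows);
  return 0;
}
```

Partial specialization on the bool keeps the call site identical in both modes; only the selected specialization knows which constructor signature to use.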
@@ -847,7 +847,7 @@ struct GemmFpAIntBSplitK {
// static_assert(print_type<Mma::>());

// Perform this tile's range of multiply-accumulate (MAC) iterations
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
Mma mma(shared_storage.main_loop, -1, thread_idx, warp_idx, lane_idx);
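// Note: -1 is passed for the new group_size argument here, presumably as a
// sentinel meaning no groupwise scales; in this diff the fine-grained
// (groupwise) path is only wired into the non-split-K kernel above.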

mma(tile_work.k_iters_remaining,
accumulator_tile,