diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h index 151153e4b297d..aef10063f1dc4 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h @@ -1,18 +1,34 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * http://www.apache.org/licenses/LICENSE-2.0 + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. @@ -42,5 +58,58 @@ namespace arch { // Tag which triggers MMA which will trigger struct OpMultiplyAddDequantizeInterleavedBToA; +/* + Below we have extra tags to signal what kind of dequantization we want to do + (per col, scale only fine grained, finegrained with zero). This still lets us + the existing template infrastructure (incl. that in CUTLASS). However, we + split out the template below into OpMultiplyAddDequantizeInterleavedBToA along + with the quantization op before instantiating the GEMM pieces. + + Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount + of code we need to duplicate. + */ +struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale; + +// The default just forwards the original operator +template +struct TagOperator { + using TaggedOperator = MmaOp; +}; + +// Specializations below attach more information to the operator +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +}; + +template <> +struct TagOperator { + using TaggedOperator = + OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale; +}; + +// Here we instantiate some structs to "detag" the tagged operator. It splits it +// back to the original operator + the extra information. If no extra info was +// tagged, the dequant op per column scaling as a default. +template +struct DetagOperator { + using Operator = TaggedMmaOp; + static constexpr bool FineGrained = false; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr bool FineGrained = false; +}; + +template <> +struct DetagOperator< + OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale> { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr bool FineGrained = true; +}; + } // namespace arch } // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h index 68bf13bb25995..972c9e1ffa628 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h @@ -61,7 +61,9 @@ enum class CutlassTileConfig { // configs for large M in encoder CtaShape128x256x64_WarpShape64x64x64, - // CtaShape256x128x64_WarpShape64x64x64 + + // configs for finegrained + CtaShape256x128x64_WarpShape64x64x64, }; enum class SplitKStyle { diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h index a14604728baf4..839745161a3d8 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -56,8 +56,10 @@ template struct GemmFpAIntB { using Mma = Mma_; @@ -103,6 +105,7 @@ struct GemmFpAIntB { /// Parameters structure struct Arguments : UniversalArgumentsBase { cutlass::gemm::GemmCoord problem_size; + int group_size; typename Mma::IteratorA::TensorRef ref_A; typename Mma::IteratorB::TensorRef ref_B; typename Mma::IteratorScale::TensorRef ref_scale; @@ -125,6 +128,7 @@ struct GemmFpAIntB { CUTLASS_HOST_DEVICE Arguments(cutlass::gemm::GemmCoord const& problem_size, + int group_size, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Mma::IteratorScale::TensorRef ref_scale, @@ -143,6 +147,7 @@ struct GemmFpAIntB { problem_size, /*serial_split_k_factor=*/serial_split_k_factor, /*batch_stride_D=*/0), + group_size(group_size), ref_A(ref_A), ref_B(ref_B), ref_scale(ref_scale), @@ -181,6 +186,7 @@ struct GemmFpAIntB { int const* gather_A_indices; int const* gather_B_indices; int const* scatter_D_indices; + int group_size; // // Methods @@ -192,6 +198,7 @@ struct GemmFpAIntB { CUTLASS_HOST_DEVICE Params(Arguments const& args, int device_sms, int sm_occupancy) : ParamsBase(args, device_sms, sm_occupancy), + group_size(args.group_size), params_A(args.ref_A.layout()), ref_A(args.ref_A), params_B(args.ref_B.layout()), @@ -276,6 +283,52 @@ struct GemmFpAIntB { return Status::kSuccess; } + // Initializes the fine grained scale+bias iterator. Needed since the fine + // grained iterator has a different constructor signature than a regular + // cutlass iterator + + template + struct initialize_scale { + CUTLASS_DEVICE static IteratorScale apply( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size); + }; + + template + struct initialize_scale { + CUTLASS_DEVICE static IteratorScale apply( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale(params, + pointer_scale, + extent, + thread_id, + threadblock_offset, + group_size); + } + }; + + template + struct initialize_scale { + CUTLASS_DEVICE static IteratorScale apply( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale( + params, pointer_scale, extent, thread_id, threadblock_offset); + } + }; static size_t get_extra_workspace_size( Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { return 0; @@ -335,8 +388,12 @@ struct GemmFpAIntB { threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + typename MatrixCoord::Index fg_row_offset = + threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = + Finegrained == true ? fg_row_offset : 0; cutlass::MatrixCoord tb_offset_scale{ - 0, threadblock_tile_offset.n() * Mma::Shape::kN}; + scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; // Problem size is a function of threadblock index in the K dimension int problem_size_k = @@ -368,11 +425,16 @@ struct GemmFpAIntB { tb_offset_B, params.gather_B_indices); - typename Mma::IteratorScale iterator_scale(params.params_scale, - params.ref_scale.data(), - {1, params.problem_size.n()}, - thread_idx, - tb_offset_scale); + typename MatrixCoord::Index scale_row_extent = + Finegrained == true ? problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = + initialize_scale::apply( + params.params_scale, + params.ref_scale.data(), + {scale_row_extent, params.problem_size.n()}, + thread_idx, + tb_offset_scale, + params.group_size); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. @@ -383,7 +445,11 @@ struct GemmFpAIntB { // Main loop // // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + Mma mma(shared_storage.main_loop, + params.group_size, + thread_idx, + warp_idx, + lane_idx); typename Mma::FragmentC accumulators; diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h index a4fda93533a1f..b4f1039579861 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h @@ -847,7 +847,7 @@ struct GemmFpAIntBSplitK { // static_assert(print_type()); // Perform this tile's range of multiply-accumulate (MAC) iterations - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + Mma mma(shared_storage.main_loop, -1, thread_idx, warp_idx, lane_idx); mma(tile_work.k_iters_remaining, accumulator_tile, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h index da4b8d73376f6..e27a9e8ee9f84 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -37,6 +37,7 @@ limitations under the License. */ #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma.h" @@ -46,6 +47,54 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// +template +struct DefaultScaleIterators; + +// Fine grained iterators +template +struct DefaultScaleIterators { + using IteratorScale = + cutlass::transform::threadblock::FineGrainedScaleZeroIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +// Per column iterators +template +struct DefaultScaleIterators { + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaShape::kN / Alignment, + Alignment>; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + IteratorScaleThreadMap, + Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +//////////////////////////////////////////////////////////////////////////////// + template < /// Type for elementA typename ElementA, @@ -80,7 +129,7 @@ template < /// Stages in GEMM int kStages, /// - typename Operator, + typename Operator_, /// SharedMemoryClearOption SharedMemoryClear> struct DqMma= 80)>::type> { + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value || platform::is_same::value, "Element A must be fp16 or bf16"); @@ -171,22 +223,15 @@ struct DqMma; - // ThreadMap for scale iterator static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); - using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - MmaCore::Shape::kN / kAlignmentScale, - kAlignmentScale>; + using ScaleIterators = DefaultScaleIterators; // Define iterators over tiles from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape<1, MmaCore::Shape::kN>, - ElementScale, - LayoutScale, - 0, - IteratorScaleThreadMap, - kAlignmentScale>; - + using IteratorScale = typename ScaleIterators::IteratorScale; using SmemIteratorScale = IteratorScale; using Converter = FastInterleavedAndBiasedNumericArrayConverter< @@ -210,7 +255,8 @@ struct DqMma; + SharedMemoryClear, + OperatorInfo::FineGrained>; }; template < @@ -245,7 +291,7 @@ template < /// Stages in GEMM int kStages, /// - typename Operator, + typename Operator_, /// SharedMemoryClearOption SharedMemoryClear, /// @@ -269,10 +315,13 @@ struct DqMma= 80)>::type> { + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value || platform::is_same::value, "Element A must be fp16 or bf16"); @@ -364,19 +413,14 @@ struct DqMma, - MmaCore::Shape::kN / kAlignmentScale, - kAlignmentScale>; + using ScaleIterators = DefaultScaleIterators; // Define iterators over tiles from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape<1, MmaCore::Shape::kN>, - ElementScale, - LayoutScale, - 0, - IteratorScaleThreadMap, - kAlignmentScale>; + using IteratorScale = typename ScaleIterators::IteratorScale; using SmemIteratorScale = IteratorScale; @@ -401,7 +445,8 @@ struct DqMma; + SharedMemoryClear, + OperatorInfo::FineGrained>; }; } // namespace threadblock diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_finegrained.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_finegrained.h new file mode 100644 index 0000000000000..5b6e8249aa80c --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_finegrained.h @@ -0,0 +1,741 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage : public DqMmaBase { + public: + ///< Base class + using Base = + DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to + /// shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, // NOLINT + /// The group size for quantization + int group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + {Base::kStages, Shape::kN}, + thread_idx, + group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, // NOLINT + int stage = -1, + int k_iter = -1) { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = + iterator_scale.get_scale(); + // typename IteratorScale::AccessType* gmem_zero_ptr = + // iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr = + reinterpret_cast( + this->smem_iterator_scale_.get_scale()); + // typename IteratorScale::AccessType* smem_zero_ptr + // = reinterpret_cast(this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * + IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async( + smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + // if (gmem_zero_ptr != nullptr) + // { + // cutlass::arch::cp_async(smem_zero_ptr, + // gmem_zero_ptr, iterator_scale.valid()); + // } + + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, // NOLINT + IteratorB& iterator_B, // NOLINT + IteratorScale& iterator_scale, // NOLINT + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, // NOLINT + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. + // (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + // typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + iterator_scale, + group_start_iteration_A, + group_start_iteration_B); + + // This is the first group of a given stage, so we issue the loads for + // the B scales immediately. + if (group_start_iteration_B == 0) { + copy_scales_and_advance(iterator_scale); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + iterator_scale, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. + // (#uncommitted = Base::kStages - 1 - #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_percol.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_percol.h new file mode 100644 index 0000000000000..7307131f8dfd3 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_percol.h @@ -0,0 +1,684 @@ + +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage + : public DqMmaBase { + public: + ///< Base class + using Base = + DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + // + // Dependent types + // + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared + /// memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, // NOLINT + ///< Group size for quantization. Not used by this main loop since it + ///< assumes per-column + int group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + {1, Shape::kN}, + thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, // NOLINT + IteratorB& iterator_B, // NOLINT + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, // NOLINT + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group + // as + // the first load of A. + FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h index bf95ed2fc3540..b6911f05a4500 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -1,33 +1,34 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * http://www.apache.org/licenses/LICENSE-2.0 + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ /*! \file \brief Template for a double-buffered threadblock-scoped GEMM kernel. */ @@ -94,559 +95,12 @@ template < /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// Used for partial specialization - typename Enable = bool> -class DqMmaMultistage : public DqMmaBase { - public: - ///< Base class - using Base = - DqMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using TransformBAfterLDS = TransformBAfterLDS_; - - // - // Dependent types - // - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. - struct Detail { - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / - Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / - Base::kWarpGemmIterations; - }; - - private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave = - layout::IsColumnMajorTileInterleave< - typename LayoutDetailsForB::Layout>::value; - static_assert(!RequiresTileInterleave || - (RequiresTileInterleave && - (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - - private: - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale operand to shared - /// memory - SmemIteratorScale smem_iterator_scale_; - - public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, // NOLINT - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx), - warp_dequantizer_( - {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / - Base::WarpCount::kM, - lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_iterator_scale_(LayoutScale(Shape::kN), - shared_storage.operand_scale.data(), - {1, Shape::kN}, - thread_idx) { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA& iterator_A, // NOLINT - IteratorB& iterator_B, // NOLINT - int group_start_A = 0, - int group_start_B = 0) { - iterator_A.set_iteration_index(group_start_A * - IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, // NOLINT - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale operand in global memory - IteratorScale iterator_scale, - ///< initial value of accumulator - FragmentC const& src_accum) { - // - // Prologue - // - - TransformBAfterLDS lds_converter; - - // NOTE - switch to ldg.sts - // Issue this first, so cp.async.commit_group will commit this load as well. - // Note: we do not commit here and this load will commit in the same group - // as - // the first load of A. - FragmentScale tb_frag_scales; - tb_frag_scales.clear(); - iterator_scale.load(tb_frag_scales); - this->smem_iterator_scale_.store(tb_frag_scales); - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // - // Clear the remaining tiles of SMEM. This is a functional requirement for - // some kernels so that all accumulator elements outside the GEMM footprint - // are zero. - // - - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { - /// Iterator to write threadblock-scoped tile of A operand to shared - /// memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - - typename IteratorA::AccessType zero_A; - zero_A.clear(); - - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast( - last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared - /// memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast( - last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename Dequantizer::FragmentScale warp_frag_scales; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - warp_dequantizer_.load(warp_frag_scales); - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % - Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - const int warp_tileB_k_compute_offset = - warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = - warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == - Base::kNumKIterationsPerWarpBLoad - 1) { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load( - warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - typename TransformBAfterLDS::result_type converted_frag_B = - lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - run_warp_mma(warp_mma, - accum, - warp_frag_A[warp_mma_k % 2], - converted_frag_B, - accum, - warp_tileB_k_compute_offset); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, - iterator_B, - group_start_iteration_A, - group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, - iterator_B, - group_start_iteration_A, - group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, - -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterationsForB, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - } - } - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM - // mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// + bool FineGrained = false> +class DqMmaMultistage; } // namespace threadblock } // namespace gemm } // namespace cutlass -///////////////////////////////////////////////////////////////////////////////////////////////// +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_finegrained.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_percol.h" diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h index 65fe9693727ea..9071e1affad16 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -195,11 +195,18 @@ class DqMmaPipelined : public DqMmaBase=80. + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp ) : Base(shared_storage, thread_idx, warp_idx, lane_idx), warp_dequantizer_( diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h index 1426182b1363c..e02e79316c460 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -194,6 +194,13 @@ class MmaTensorOpDequantizer< } } + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_ += offset; + } + private: ElementScale const* pointer_; }; @@ -297,6 +304,13 @@ class MmaTensorOpDequantizer< } } + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_ += offset; + } + private: ElementScale const* pointer_; }; diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h new file mode 100644 index 0000000000000..24c95134cfe29 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -0,0 +1,277 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Templates for visiting scales to be used when dequantizing the + weights for weight-only GEMM quantization. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +// #define _DEBUG_CUTLASS_FINE_GRAINED_SCALE_ZERO_ITERATOR + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +class FineGrainedScaleZeroIterator; + +template +class FineGrainedScaleZeroIterator { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + // For compatibility with existing iterator interface + struct Params { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : stride_(layout.stride(0)) { // NOLINT + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + // BytePointer pointer_zero_; + + bool is_valid_ = false; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + // Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params), + pointer_scale_(reinterpret_cast( + const_cast(pointer_scale))) { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset = threadblock_offset.row() / + (group_size / 64) * params_.stride_ * + sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = + threadblock_offset.column() * sizeof_bits::value / 8; + + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + // TODO(freeliuzc): support ZERO + // if (pointer_zero_ != nullptr) + // { + // pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); + // } + + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + const int thread_row = thread_id / THREADS_PER_ROW; + const int thread_col = thread_id % THREADS_PER_ROW; + const LongIndex thread_row_byte_offset = + thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = + thread_col * kAlignment * sizeof_bits::value / 8; + + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + // TODO(freeliuzc): support ZERO + // if (pointer_zero_ != nullptr) + // { + // pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); + // } + + // For the rows, we must check that we are within the extent AND the tile to + // avoid extra reads on a given iteration. The same threads will be + // responsible for issues reads since the number of scales read in a given + // iteration is a constant. Therefore, we should never have to update + // is_valid_ outside of the constructor. + const int global_row = threadblock_offset.row() + thread_row; + const int global_col = + threadblock_offset.column() + thread_col * kAlignment; + + const bool row_in_bounds = + global_row < extent.row() && thread_row < Shape::kRow; + const bool col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator( + Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + // Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator(params, + pointer_scale, + extent, + thread_id, + make_Coord(0, 0), + group_size) {} + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = + tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + + // TODO(freeliuzc): support ZERO + // if (pointer_zero_ != nullptr) + // { + // pointer_zero_ += row_byte_offset + col_byte_offset; + // } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return is_valid_; } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const { + return reinterpret_cast(pointer_scale_); + } + + // TODO(freeliuzc): support ZERO + // Returns a zero pointer + // CUTLASS_HOST_DEVICE + // AccessType* get_zero() const + // { + // return reinterpret_cast(pointer_zero_); + // } +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h index 38048a08f9c0d..ff878f896a74d 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h @@ -56,6 +56,8 @@ static TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { // {256, 128} have better performance than 128, 128 case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileShape{128, 256}; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + return TileShape{256, 128}; default: throw std::runtime_error( "[fpA_intB_gemm Error][get_grid_shape_for_config] Invalid config"); @@ -106,7 +108,8 @@ static std::vector get_candidate_tiles( const bool is_weight_only, const bool is_weight_only_encoder, const bool simt_configs_only, - const int sm) { + const int sm, + const int group_size) { VLOG(3) << "get_candidate_tiles sm: " << sm; std::vector simt_configs{ CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; @@ -124,13 +127,23 @@ static std::vector get_candidate_tiles( CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x64x64, CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64, - CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64}; + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + }; + std::vector quant_B_configs_sm80_finegrained{ + CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + }; std::vector quant_B_configs; switch (sm) { case 86: - case 80: - quant_B_configs = quant_B_configs_sm80; + case 80: { + quant_B_configs = group_size > 0 ? quant_B_configs_sm80_finegrained + : quant_B_configs_sm80; break; + } case 75: case 70: quant_B_configs = quant_B_configs_sm70; @@ -147,12 +160,17 @@ static std::vector get_candidate_tiles( } static std::vector get_candidate_configs( - int sm, + const int sm, + const int group_size, const bool is_weight_only, const bool is_weight_only_encoder, const bool simt_configs_only) { - std::vector tiles = get_candidate_tiles( - is_weight_only, is_weight_only_encoder, simt_configs_only, sm); + std::vector tiles = + get_candidate_tiles(is_weight_only, + is_weight_only_encoder, + simt_configs_only, + sm, + group_size); std::vector candidate_configs; const int min_stages = 2; @@ -174,11 +192,13 @@ static CutlassGemmConfig estimate_best_config_from_occupancies( const int64_t m, const int64_t n, const int64_t k, + const int group_size, const int64_t num_experts, const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, - const int is_weight_only) { + const int is_weight_only, + const int sm) { if (occupancies.size() != candidate_configs.size()) { throw std::runtime_error( "[fpA_intB_gemm Error][estimate_best_config_from_occupancies] " @@ -187,14 +207,41 @@ static CutlassGemmConfig estimate_best_config_from_occupancies( } CutlassGemmConfig best_config; - if (m >= 256 && + + if (m >= 256 && sm == 86 && group_size > 0 && std::find_if( candidate_configs.begin(), candidate_configs.end(), [](const CutlassGemmConfig& gemm_config) { return gemm_config.tile_config == - CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64; + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64; }) != candidate_configs.end()) { + best_config = CutlassGemmConfig{ + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + SplitKStyle::NO_SPLIT_K, + 1, + 2}; + } else if (m >= 256 && sm == 80 && group_size > 0 && + std::find_if(candidate_configs.begin(), + candidate_configs.end(), + [](const CutlassGemmConfig& gemm_config) { + return gemm_config.tile_config == + CutlassTileConfig:: + CtaShape256x128x64_WarpShape64x64x64; + }) != candidate_configs.end()) { + best_config = CutlassGemmConfig{ + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + SplitKStyle::NO_SPLIT_K, + 1, + 4}; + } else if (m >= 256 && sm == 80 && group_size <= 0 && + std::find_if(candidate_configs.begin(), + candidate_configs.end(), + [](const CutlassGemmConfig& gemm_config) { + return gemm_config.tile_config == + CutlassTileConfig:: + CtaShape128x256x64_WarpShape64x64x64; + }) != candidate_configs.end()) { best_config = CutlassGemmConfig{ CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, SplitKStyle::NO_SPLIT_K, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h index 15c5267ae0f9d..0fef3771f2f05 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h @@ -63,6 +63,7 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream); @@ -75,6 +76,7 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, std::string activation_type, char* workspace_ptr, const size_t workspace_bytes, @@ -84,7 +86,7 @@ class CutlassFpAIntBGemmRunner { int getWorkspaceSize(const int m, const int n, const int k); private: - template + template void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, @@ -93,13 +95,14 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); - template + template void run_gemm(const T* A, const WeightType* B, const T* weight_scales, @@ -108,6 +111,7 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream); @@ -136,6 +140,7 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream); @@ -148,6 +153,7 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, std::string activation_type, char* workspace_ptr, const size_t workspace_bytes, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu index 2f566d4dbc35e..dce644bd7ae1d 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu @@ -42,6 +42,7 @@ template void dispatch_gemm_config(const T* A, @@ -52,6 +53,7 @@ void dispatch_gemm_config(const T* A, int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, @@ -63,6 +65,7 @@ void dispatch_gemm_config(const T* A, WeightType, arch, EpilogueTag, + FineGrained, ThreadblockShape, WarpShape, 2>; @@ -74,6 +77,7 @@ void dispatch_gemm_config(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -85,6 +89,7 @@ void dispatch_gemm_config(const T* A, WeightType, arch, EpilogueTag, + FineGrained, ThreadblockShape, WarpShape, 3>; @@ -96,6 +101,7 @@ void dispatch_gemm_config(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -107,6 +113,7 @@ void dispatch_gemm_config(const T* A, WeightType, arch, EpilogueTag, + FineGrained, ThreadblockShape, WarpShape, 4>; @@ -118,6 +125,7 @@ void dispatch_gemm_config(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -129,6 +137,7 @@ void dispatch_gemm_config(const T* A, WeightType, arch, EpilogueTag, + FineGrained, ThreadblockShape, WarpShape, 5>; @@ -140,6 +149,7 @@ void dispatch_gemm_config(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -154,7 +164,11 @@ void dispatch_gemm_config(const T* A, } } -template +template void dispatch_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, @@ -163,6 +177,7 @@ void dispatch_gemm_to_cutlass(const T* A, int m, int n, int k, + int group_size, char* workspace, size_t workspace_bytes, CutlassGemmConfig gemm_config, @@ -179,6 +194,7 @@ void dispatch_gemm_to_cutlass(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<16, 128, 64>, cutlass::gemm::GemmShape<16, 32, 64>>( A, @@ -189,6 +205,7 @@ void dispatch_gemm_to_cutlass(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -201,6 +218,7 @@ void dispatch_gemm_to_cutlass(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>( A, @@ -211,6 +229,7 @@ void dispatch_gemm_to_cutlass(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -222,6 +241,7 @@ void dispatch_gemm_to_cutlass(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>>( A, @@ -232,6 +252,7 @@ void dispatch_gemm_to_cutlass(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -244,6 +265,7 @@ void dispatch_gemm_to_cutlass(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>>( A, @@ -254,6 +276,7 @@ void dispatch_gemm_to_cutlass(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -266,6 +289,7 @@ void dispatch_gemm_to_cutlass(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>>( A, @@ -276,6 +300,30 @@ void dispatch_gemm_to_cutlass(const T* A, m, n, k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, gemm_config, workspace, workspace_bytes, @@ -300,7 +348,11 @@ void dispatch_gemm_to_cutlass(const T* A, } } -template +template void dispatch_gemm_to_cutlass_sm7x(const T* A, const WeightType* B, const T* weight_scales, @@ -309,6 +361,7 @@ void dispatch_gemm_to_cutlass_sm7x(const T* A, int m, int n, int k, + int group_size, char* workspace, size_t workspace_bytes, CutlassGemmConfig gemm_config, @@ -324,6 +377,7 @@ void dispatch_gemm_to_cutlass_sm7x(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>( A, @@ -334,6 +388,7 @@ void dispatch_gemm_to_cutlass_sm7x(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -345,6 +400,7 @@ void dispatch_gemm_to_cutlass_sm7x(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>>( A, @@ -355,6 +411,7 @@ void dispatch_gemm_to_cutlass_sm7x(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -394,8 +451,9 @@ CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { } template -template -void CutlassFpAIntBGemmRunner::dispatch_to_arch( +template +void CutlassFpAIntBGemmRunner::dispatch_to_arch( const T* A, const WeightType* B, const T* weight_scales, @@ -404,6 +462,7 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes, @@ -415,19 +474,21 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( dispatch_gemm_to_cutlass_sm7x(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - gemm_config, - stream, - occupancy); + EpilogueTag, + false>(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); #else throw std::runtime_error( "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " @@ -438,19 +499,21 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( dispatch_gemm_to_cutlass_sm7x(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - gemm_config, - stream, - occupancy); + EpilogueTag, + false>(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); #else throw std::runtime_error( "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " @@ -458,20 +521,24 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( #endif } else if (sm_ >= 80 && sm_ < 90) { #if defined(USE_FPAINTB_GEMM_WITH_SM80) - dispatch_gemm_to_cutlass( - A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - gemm_config, - stream, - occupancy); + dispatch_gemm_to_cutlass(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); #else throw std::runtime_error( "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " @@ -485,8 +552,9 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( } template -template -void CutlassFpAIntBGemmRunner::run_gemm( +template +void CutlassFpAIntBGemmRunner::run_gemm( const T* A, const WeightType* B, const T* weight_scales, @@ -495,30 +563,32 @@ void CutlassFpAIntBGemmRunner::run_gemm( int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { // VLOG(3)<<__PRETTY_FUNCTION__; static constexpr bool is_weight_only = !std::is_same::value; const bool is_weight_only_encoder = m >= 512 ? true : false; - std::vector candidate_configs = - get_candidate_configs(sm_, is_weight_only, is_weight_only_encoder, false); + std::vector candidate_configs = get_candidate_configs( + sm_, group_size, is_weight_only, is_weight_only_encoder, false); std::vector occupancies(candidate_configs.size()); for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { - dispatch_to_arch(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - candidate_configs[ii], - workspace_ptr, - workspace_bytes, - stream, - &occupancies[ii]); + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + candidate_configs[ii], + workspace_ptr, + workspace_bytes, + stream, + &occupancies[ii]); } // Standard GEMM, so 1 "expert". We use the same function for MoE and regular // FFN. @@ -529,24 +599,27 @@ void CutlassFpAIntBGemmRunner::run_gemm( m, n, k, + group_size, num_experts, split_k_limit, workspace_bytes, multi_processor_count_, - is_weight_only); + is_weight_only, + sm_); - dispatch_to_arch(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - chosen_config, - workspace_ptr, - workspace_bytes, - stream); + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + chosen_config, + workspace_ptr, + workspace_bytes, + stream); } template @@ -559,6 +632,7 @@ void CutlassFpAIntBGemmRunner::gemm_bias_act( int m, int n, int k, + int group_size, std::string activation_type, char* workspace_ptr, const size_t workspace_bytes, @@ -570,17 +644,37 @@ void CutlassFpAIntBGemmRunner::gemm_bias_act( PADDLE_THROW(phi::errors::Unimplemented( "Activation_type = relu for fpA_intB gemm is not instantiated.")); } else if (activation_type == "none") { - run_gemm(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - stream); + if (group_size > 0) { + PADDLE_ENFORCE_GE(sm_, + 80, + phi::errors::Unimplemented( + "Groupwise mode is not supported on SM < 8.0")); + run_gemm(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + stream); + } else { + run_gemm(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + stream); + } } else { throw std::runtime_error(("Invalid activation type.")); } @@ -594,21 +688,41 @@ void CutlassFpAIntBGemmRunner::gemm(const T* A, int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { - // VLOG(3)<<__PRETTY_FUNCTION__; - run_gemm(A, - B, - weight_scales, - nullptr, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - stream); + if (group_size > 0) { + PADDLE_ENFORCE_GE(sm_, + 80, + phi::errors::Unimplemented( + "Groupwise mode is not supported on SM < 8.0")); + run_gemm(A, + B, + weight_scales, + nullptr, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + stream); + } else { + run_gemm(A, + B, + weight_scales, + nullptr, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + stream); + } } template @@ -636,6 +750,7 @@ void CutlassFpAIntBGemmRunner::gemm_bias_act( int m, int n, int k, + int group_size, std::string activation_type, char* workspace_ptr, const size_t workspace_bytes, @@ -654,6 +769,7 @@ void CutlassFpAIntBGemmRunner::gemm( int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index 8ae1047c43afc..f7c73dc99cede 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -54,6 +54,7 @@ template @@ -65,6 +66,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, @@ -117,7 +119,13 @@ void generic_mixed_gemm_kernelLauncher(const T* A, MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op; - if (gemm_config.split_k_style == SplitKStyle::NO_SPLIT_K) { + + if (gemm_config.split_k_style == SplitKStyle::NO_SPLIT_K || + FineGrained == true) { + using Operator = typename MixedGemmArchTraits::Operator; + using TaggedOperator = + typename cutlass::arch::TagOperator::TaggedOperator; using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< ElementType, cutlass::layout::RowMajor, @@ -137,14 +145,15 @@ void generic_mixed_gemm_kernelLauncher(const T* A, typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, Stages, true, - typename MixedGemmArchTraits::Operator>::GemmKernel; + TaggedOperator>::GemmKernel; using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB< typename GemmKernel_::Mma, typename GemmKernel_::Epilogue, typename GemmKernel_::ThreadblockSwizzle, arch, // Ensure top level arch is used for dispatch - GemmKernel_::kSplitKSerial>; + GemmKernel_::kSplitKSerial, + FineGrained>; if (occupancy != nullptr) { *occupancy = compute_occupancy_for_kernel(); @@ -161,9 +170,10 @@ void generic_mixed_gemm_kernelLauncher(const T* A, typename Gemm::Arguments args( {m, n, k}, + group_size, {reinterpret_cast(const_cast(A)), k}, {reinterpret_cast(const_cast(B)), ldb}, - {reinterpret_cast(const_cast(weight_scales)), 0}, + {reinterpret_cast(const_cast(weight_scales)), n}, {reinterpret_cast(const_cast(biases)), 0}, {reinterpret_cast(C), n}, gemm_config.split_k_factor, @@ -221,7 +231,8 @@ void generic_mixed_gemm_kernelLauncher(const T* A, std::string(cutlassGetStatusString(run_status)); throw std::runtime_error("[fpA_intB Runner] " + err_msg); } - } else { + + } else /* Per-Channel mode */ { // for stream-k, we set gemm_config.split_k_factor = 1 to use default load // balance. gemm_config.split_k_factor = 1; @@ -334,6 +345,7 @@ template @@ -345,6 +357,7 @@ void generic_mixed_gemm_kernelLauncher_template(const T* A, int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, @@ -355,6 +368,7 @@ template struct dispatch_stages { @@ -401,6 +418,7 @@ struct dispatch_stages(A, @@ -422,6 +441,7 @@ struct dispatch_stages @@ -441,6 +462,7 @@ struct dispatch_stages(A, @@ -472,6 +496,7 @@ struct dispatch_stages void dispatch_gemm_config(const T* A, @@ -495,13 +521,18 @@ void dispatch_gemm_config(const T* A, int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy); -template +template void dispatch_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, @@ -510,6 +541,7 @@ void dispatch_gemm_to_cutlass(const T* A, int m, int n, int k, + int group_size, char* workspace, size_t workspace_bytes, CutlassGemmConfig gemm_config, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py index ad7f1e65591ce..5847956020ceb 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py @@ -33,6 +33,7 @@ {WeightType}, {arch}, {EpilogueTag}, + {FineGrained}, {ThreadblockShape}, {WarpShape}, {Stages}>( @@ -44,6 +45,7 @@ int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, @@ -53,6 +55,7 @@ {WeightType}, {arch}, {EpilogueTag}, + {FineGrained}, {ThreadblockShape}, {WarpShape}, {Stages}>( @@ -64,6 +67,7 @@ m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -87,6 +91,7 @@ "cutlass::gemm::GemmShape<64, 128, 64>", "cutlass::gemm::GemmShape<128, 128, 64>", "cutlass::gemm::GemmShape<128, 256, 64>", + "cutlass::gemm::GemmShape<256, 128, 64>", ] WarpShapes = [ "cutlass::gemm::GemmShape<16, 32, 64>", @@ -94,6 +99,7 @@ "cutlass::gemm::GemmShape<64, 64, 64>", "cutlass::gemm::GemmShape<64, 64, 64>", "cutlass::gemm::GemmShape<64, 64, 64>", + "cutlass::gemm::GemmShape<64, 64, 64>", ] ThreadblockShapes_sm70 = [ @@ -119,6 +125,9 @@ # "biasReLU": "EpilogueOpBiasReLU", } +FineGrainedTypes = ["true", "false"] +FineGrainedTypes_sm70 = ["false"] + def SubstituteTemplate(template, values): text = template @@ -174,28 +183,36 @@ def parse_args(): # generate source cu def generate_source_cu( - element_type: str, arch: int, epilogue_tag: str, stages: int + element_type: str, + arch: int, + epilogue_tag: str, + stages: int, ): all_code = CommonHead ThreadblockShapes_arch = ThreadblockShapes WarpShapes_arch = WarpShapes + FineGrainedTypes_arch = FineGrainedTypes + if arch < 80: ThreadblockShapes_arch = ThreadblockShapes_sm70 WarpShapes_arch = WarpShapes_sm70 + FineGrainedTypes_arch = FineGrainedTypes_sm70 for WeightType in WeightTypes: for i in range(len(ThreadblockShapes_arch)): - value_dict = { - "T": ElementTypes[element_type], - "WeightType": WeightType, - "arch": Archs[arch], - "EpilogueTag": EpilogueTags[epilogue_tag], - "ThreadblockShape": ThreadblockShapes_arch[i], - "WarpShape": WarpShapes_arch[i], - "Stages": str(stages), - } - all_code += SubstituteTemplate( - DispatchGemmConfigInstanceDeclare, value_dict - ) + for j in range(len(FineGrainedTypes_arch)): + value_dict = { + "T": ElementTypes[element_type], + "WeightType": WeightType, + "arch": Archs[arch], + "EpilogueTag": EpilogueTags[epilogue_tag], + "FineGrained": FineGrainedTypes_arch[j], + "ThreadblockShape": ThreadblockShapes_arch[i], + "WarpShape": WarpShapes_arch[i], + "Stages": str(stages), + } + all_code += SubstituteTemplate( + DispatchGemmConfigInstanceDeclare, value_dict + ) all_code += CommonTail return all_code @@ -221,7 +238,10 @@ def generate_source_cu( element_type, arch, stages, epilogue_tag ) all_code = generate_source_cu( - element_type, arch, epilogue_tag, stages + element_type, + arch, + epilogue_tag, + stages, ) with open(file_name, "w") as f: f.write(all_code) diff --git a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu index c41b86148291d..901a291d3924d 100644 --- a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu @@ -89,6 +89,7 @@ we havenot support sm70 weightonly gemv, because sm70 weight layout is RowMajor. m, n, k, + group_size, "none", mixgemm_workspace_data, mixgemm_workspace_size_bytes, @@ -104,6 +105,7 @@ we havenot support sm70 weightonly gemv, because sm70 weight layout is RowMajor. m, n, k, + group_size, mixgemm_workspace_data, mixgemm_workspace_size_bytes, dev_ctx.stream()); @@ -134,6 +136,7 @@ we havenot support sm70 weightonly gemv, because sm70 weight layout is RowMajor. m, n, k, + group_size, "none", mixgemm_workspace_data, mixgemm_workspace_size_bytes, @@ -149,6 +152,7 @@ we havenot support sm70 weightonly gemv, because sm70 weight layout is RowMajor. m, n, k, + group_size, mixgemm_workspace_data, mixgemm_workspace_size_bytes, dev_ctx.stream()); diff --git a/test/quantization/test_weight_only_linear.py b/test/quantization/test_weight_only_linear.py index f3749d0b4fb15..f09698e4a1a68 100644 --- a/test/quantization/test_weight_only_linear.py +++ b/test/quantization/test_weight_only_linear.py @@ -109,7 +109,7 @@ def weightQuantizeCPUGPUConsistenceCheck(self, weight_float): def setUp(self): self.config() if self.dtype == "bfloat16" or self.weight_dtype == "int4": - self.atol = 1.5e-1 + self.atol = 1.3e-1 x = np.random.random((self.batch, self.token, self.in_features)) self.x = paddle.to_tensor(x, dtype=self.dtype) if self.bias: @@ -451,8 +451,10 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() or get_cuda_version() < 11020, - "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul groupwise mode need CUDA >= 11.2 and CUDA_ARCH >= 8", ) class WeightOnlyLinearTestCase17(WeightOnlyLinearTestCase): def config(self): @@ -466,8 +468,10 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() or get_cuda_version() < 11020, - "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul groupwise mode need CUDA >= 11.2 and CUDA_ARCH >= 8", ) class WeightOnlyLinearTestCase18(WeightOnlyLinearTestCase): def config(self): @@ -576,6 +580,78 @@ def config(self): self.out_features = 288 +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase25(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + self.group_size = 128 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase26(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + self.group_size = 64 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase27(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int4" + self.group_size = 128 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase28(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + self.token = 300 + self.group_size = 128 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase29(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int8" + self.token = 300 + self.group_size = 128 + + @unittest.skipIf( not core.is_compiled_with_cuda() or get_cuda_version() < 11020, "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",