From ea7888cfbd1541bb7cfe86720f1fc3c2dd1c2810 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?=
Date: Fri, 12 Jan 2024 14:34:00 +0100
Subject: [PATCH] Add int4b_t/uint4b_t support for mixed dtypes GEMM

---
 .../cutlass/gemm/threadblock/mma_multistage.h |  36 ++--
 .../gemm/warp/default_mma_tensor_op_sm80.h    |   3 +-
 .../gemm/warp/mma_mixed_input_tensor_op.h     | 167 ++++++++++++++++--
 test/unit/gemm/device/CMakeLists.txt          |   2 +
 ...s4n_f16t_mixed_input_tensor_op_f16_sm80.cu |  97 ++++++++++
 5 files changed, 283 insertions(+), 22 deletions(-)
 create mode 100644 test/unit/gemm/device/gemm_universal_f16t_s4n_f16t_mixed_input_tensor_op_f16_sm80.cu

diff --git a/include/cutlass/gemm/threadblock/mma_multistage.h b/include/cutlass/gemm/threadblock/mma_multistage.h
index ef55131707..04d3b9b147 100644
--- a/include/cutlass/gemm/threadblock/mma_multistage.h
+++ b/include/cutlass/gemm/threadblock/mma_multistage.h
@@ -181,6 +181,15 @@ class MmaMultistage :
     /// Pair of B fragments used to overlap shared memory loads and math instructions
     WarpLoadedFragmentB warp_loaded_frag_B_[2];
     WarpTransformedFragmentB warp_transformed_frag_B_[2];
+
+    using ElementA = typename WarpLoadedFragmentA::Element;
+    using ElementB = typename WarpLoadedFragmentB::Element;
+    static constexpr size_t sizeof_bits_A =
+        cutlass::sizeof_bits<ElementA>::value;
+    static constexpr size_t sizeof_bits_B =
+        cutlass::sizeof_bits<ElementB>::value;
+    static constexpr bool is_mixed_and_B_4bit =
+        (sizeof_bits_A != sizeof_bits_B) && (sizeof_bits_B == 4);
   };

@@ -254,7 +263,7 @@ class MmaMultistage :
     if (smem_read_stage_idx_ == Base::kStages) {
       // Wrap back around to the 'start' of the circular buffer in shared memory
       this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
-      this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
+      this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations / (PipeState::is_mixed_and_B_4bit ? 2 : 1), 0});
       smem_read_stage_idx_ = 0;
     }
   }
@@ -510,17 +519,23 @@ class MmaMultistage :
       ++this->warp_tile_iterator_A_;

       // Load the next warp-tile's B fragment from shared memory
-      this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
-      this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
-      ++this->warp_tile_iterator_B_;
+      if constexpr (!PipeState::is_mixed_and_B_4bit) {
+        this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+        this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
+        ++this->warp_tile_iterator_B_;
+      } else if ((warp_mma_k + 1) % 2 == 0) {
+        this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k / 2 + 1) % Base::kWarpGemmIterations);
+        this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k / 2 + 1) % 2]);
+        ++this->warp_tile_iterator_B_;
+      }

       // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
       if (warp_mma_k > 0) {
         warp_mma_.transform(
           pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
-          pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
+          pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2],
           pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
-          pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
+          pipe_state.warp_loaded_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2]);
       }

       // Execute the current warp-tile of MMA operations
@@ -528,7 +543,7 @@ class MmaMultistage :
         warp_mma_(
           pipe_state.tmp_accum_,
           pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
-          pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
+          pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2],
           pipe_state.tmp_accum_
         );

@@ -541,7 +556,7 @@ class MmaMultistage :
         warp_mma_(
           accum,
           pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
-          pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
+          pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2],
           accum
         );
       }
@@ -596,12 +611,11 @@ class MmaMultistage :
       // the first warp-tile of the next iteration, if necessary (so we can
       // immediately start issuing MMA instructions at the top of the loop )
       if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
-
         warp_mma_.transform(
           pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
-          pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
+          pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2 + 1) % 2 : (warp_mma_k + 1) % 2],
           pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
-          pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
+          pipe_state.warp_loaded_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2 + 1) % 2 : (warp_mma_k + 1) % 2]);
       }
     }

diff --git a/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
index 67fcde77e5..a159ac00ef 100644
--- a/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
+++ b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
@@ -229,7 +229,8 @@ struct DefaultMmaTensorOp<
 /////////////////////////////////////////////////////////////////////////////////////////////////

 /// Partial Specialization - inputs are mixed types - uses wider datatype internally.
-/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32)
+/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32,
+///  or F16 <= F16 x S4 + F16, F16 <= BF16 x S4 + F32)
 template <
   /// Shape of one matrix production operation (concept: GemmShape)
   typename WarpShape_,
diff --git a/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
index f553fbde99..727b7edf03 100644
--- a/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
+++ b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
@@ -266,6 +266,117 @@ struct FragmentShuffler
+template <typename ElementMma_, typename ElementLoad_,
+          int NumMmaInstructions,
+          int NumElementsInWarpFragment,
+          int NumElementsInMmaFragment>
+struct FragmentShuffler <ElementMma_, ElementLoad_,
+                         NumMmaInstructions,
+                         NumElementsInWarpFragment,
+                         NumElementsInMmaFragment,
+                         Operand::kB,
+                         typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
+                                                      (sizeof_bits<ElementLoad_>::value == 4)>::type> {
+public:
+  using ElementMma = ElementMma_;
+  using ElementLoad = ElementLoad_;
+
+  static int const kNumMmaInstructions = NumMmaInstructions;
+  static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
+  static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
+  static Operand const kOperand = Operand::kB;
+
+  using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
+  using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;
+
+private:
+  int src_lane_0_, src_lane_1_;
+  uint32_t byte_selector_0_, byte_selector_10_, byte_selector_11_;
+  int dst_incr_0_, dst_incr_1_;
+
+public:
+  CUTLASS_DEVICE
+  FragmentShuffler() {
+    int lane_id = cutlass::arch::LaneId();
+    int mul;
+
+    src_lane_0_ = lane_id ^ 1;
+    mul = lane_id & 1;
+    byte_selector_0_ = mul * 0x3715 + (1 - mul) * 0x6240;
+
+    src_lane_1_ = lane_id ^ 2;
+    mul = (lane_id & 2) >> 1;
+    byte_selector_10_ = mul * 0x7632 + (1 - mul) * 0x5410;
+    byte_selector_11_ = mul * 0x5410 + (1 - mul) * 0x7632;
+    dst_incr_0_ = mul * (WarpFragment::kElements / 16);
+    dst_incr_1_ = (1 - mul) * (WarpFragment::kElements / 16);
+  }
+
+  CUTLASS_DEVICE
+  WarpFragment operator()(WarpFragment const &src) {
+
+    WarpFragment result;
+
+    MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const *>(&src);
+    MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment *>(&result);
+
+    uint32_t const* src_ptr = reinterpret_cast<uint32_t const *>(&mma_frag_src_ptr[0]);
+    uint32_t* dst_ptr = reinterpret_cast<uint32_t *>(&mma_frag_dst_ptr[0]);
+
+    // The code assumes that twice as many values as needed for an
+    // F16/BF16 MMA are loaded along the contiguous dimension. E.g. in
+    // the case of a column-major matrix, threads 0-3 would hold 32
+    // elements of the first column in the warp fragment, threads 4-7
+    // would hold 32 elements of the second column, etc.; but only the
+    // first 16 elements of each column will be used for the first MMA
+    // operation, and the last 16 elements will be used for the
+    // follow-up MMA operation. This code redistributes the input
+    // values across threads so that the left (for a row-major matrix)
+    // or upper (for a column-major matrix) half of the values comes
+    // first, and the right/lower half comes second in the
+    // corresponding warp fragments. The values are also redistributed
+    // between threads so that each value ends up in the thread that
+    // needs it for the F16/BF16 MMA that takes place after the
+    // up-casting.
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int n = 0; n < WarpFragment::kElements / 16; n++) {
+      // Exchange values with the neighboring thread that loaded values
+      // from the same half of all values as this thread; then combine
+      // the values in such a way that the final values can be produced
+      // after one more exchange.
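+      // (__byte_perm(x, y, s) picks result byte i from the 8-byte pool
+      //  {x.byte0..3, y.byte0..3} according to the i-th nibble of s, and
+      //  each byte here packs two int4 values; for even lanes, for
+      //  instance, selector 0x6240 interleaves bytes 0 and 2 of this
+      //  thread's word with bytes 0 and 2 of the word received from the
+      //  neighboring lane.)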
+      uint32_t tmp0 = __shfl_sync(0xFFFFFFFF, src_ptr[2 * n], src_lane_0_);
+      uint32_t tmp1 = __shfl_sync(0xFFFFFFFF, src_ptr[2 * n + 1], src_lane_0_);
+      tmp0 = __byte_perm(src_ptr[2 * n], tmp0, byte_selector_0_);
+      tmp1 = __byte_perm(src_ptr[2 * n + 1], tmp1, byte_selector_0_);
+
+      // Exchange values with the corresponding thread from the same
+      // quadruple as this thread, the one that loaded values from the
+      // other half of all values. Then combine the values to produce
+      // the final values held by this thread.
+      uint32_t mine   = __byte_perm(tmp0, tmp1, byte_selector_10_);
+      uint32_t theirs = __byte_perm(tmp0, tmp1, byte_selector_11_);
+      theirs = __shfl_sync(0xFFFFFFFF, theirs, src_lane_1_);
+      dst_ptr[n + dst_incr_0_] = mine;
+      dst_ptr[n + dst_incr_1_] = theirs;
+    }
+
+    return result;
+  }
+};

 ////////////////////////////////////////////////////////////////////////////////
@@ -414,10 +525,21 @@ class MmaMixedInputTensorOp {
 public:

+  // Chosen so we get K=16 for int8 and K=32 for int4.
+  static constexpr int LoadInstructionM =
+      (sizeof_bits<ElementB>::value > sizeof_bits<ElementA>::value)
+          ? 8 * sizeof_bits<ElementB>::value / sizeof_bits<ElementA>::value
+          : InstructionShape::kM;
+
+  // Shape for loading the data type from shared memory, accounting
+  // for a possibly narrower ElementA.
+  using LoadInstructionShapeA =
+      GemmShape<LoadInstructionM, InstructionShape::kN, InstructionShape::kK>;
+
   /// Iterates over the A operand in Shared Memory
   using IteratorA = MmaTensorOpMultiplicandTileIterator<
      MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
-     MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
+     MatrixShape<LoadInstructionShapeA::kM, LoadInstructionShapeA::kK>,
      Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

   /// Storage for A tile in registers (loaded from Shared Memory)
@@ -430,10 +552,21 @@ class MmaMixedInputTensorOp {
   /// Underlying arch::Mma instruction operand fragement for matrix A
   using MmaOperandA = typename ArchMmaOperator::FragmentA;

+  // Chosen so we get K=16 for int8 and K=32 for int4.
+  static constexpr int LoadInstructionK =
+      (sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value)
+          ? 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value
+          : InstructionShape::kK;
+
+  // Shape for loading the data type from shared memory, accounting
+  // for a possibly narrower ElementB.
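+  // (For instance, with a 16-bit ElementA, an 8-bit ElementB gives
+  //  LoadInstructionK = 8 * 16 / 8 = 16, while a 4-bit ElementB gives
+  //  LoadInstructionK = 8 * 16 / 4 = 32, i.e. the K values quoted in
+  //  the comment above.)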
+  using LoadInstructionShapeB =
+      GemmShape<InstructionShape::kM, InstructionShape::kN, LoadInstructionK>;
+
   /// Iterates over the B operand in Shared Memory
   using IteratorB = MmaTensorOpMultiplicandTileIterator<
      MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
-     MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
+     MatrixShape<LoadInstructionShapeB::kK, LoadInstructionShapeB::kN>,
      Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

   /// Storage for B tile in registers (loaded from Shared Memory)
@@ -494,6 +627,13 @@ class MmaMixedInputTensorOp {
     MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
     MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);

+    if constexpr (is_B_4bit) {
+      if (!transform_B_flag_) {
+        ptr_B += TransformedFragmentB::kElements / 2 / MmaOperandB::kElements;
+      }
+      transform_B_flag_ = !transform_B_flag_;
+    }
+
     CUTLASS_PRAGMA_UNROLL
     for (int m = 0; m < MmaIterations::kRow; ++m) {

@@ -524,15 +664,17 @@ class MmaMixedInputTensorOp {
   void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
                  FragmentA const &A, FragmentB const &B) const {

-    // Shuffle data within warp to obtain the mma.sync operand layout
-    detail::FragmentShuffler shuffler_B;
-    FragmentB tmp_B;
-    tmp_B = shuffler_B(B);
+    if (!is_B_4bit || transform_B_flag_) {
+      // Shuffle data within warp to obtain the mma.sync operand layout
+      detail::FragmentShuffler shuffler_B;
+      FragmentB tmp_B;
+      tmp_B = shuffler_B(B);

-    // Convert the B operand to the Mma Instruction operand type
-    detail::FragmentConverter convert_B;
-    dst_B = convert_B(tmp_B);
+      // Convert the B operand to the Mma Instruction operand type
+      detail::FragmentConverter convert_B;
+      dst_B = convert_B(tmp_B);
+    }

     FragmentA tmp_A;

@@ -555,6 +697,11 @@ class MmaMixedInputTensorOp {
     ptr_dst_A[1] = convert_A(ptr_tmp_A[1]);
   }

+
+private:
+  static constexpr bool is_B_4bit = cutlass::sizeof_bits<ElementB>::value == 4;
+  static_assert(!is_B_4bit || FragmentB::kElements % 16 == 0);
+  mutable bool transform_B_flag_ = true;
 };

 /////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt
index 488c6bfa6f..4d464823eb 100644
--- a/test/unit/gemm/device/CMakeLists.txt
+++ b/test/unit/gemm/device/CMakeLists.txt
@@ -284,6 +284,8 @@ cutlass_test_unit_add_executable(

   gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu
   gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu
+
+  gemm_universal_f16t_s4n_f16t_mixed_input_tensor_op_f16_sm80.cu
 )

 cutlass_test_unit_add_executable(
diff --git a/test/unit/gemm/device/gemm_universal_f16t_s4n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_f16t_s4n_f16t_mixed_input_tensor_op_f16_sm80.cu
new file mode 100644
index 0000000000..7452c746f1
--- /dev/null
+++ b/test/unit/gemm/device/gemm_universal_f16t_s4n_f16t_mixed_input_tensor_op_f16_sm80.cu
@@ -0,0 +1,97 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide GEMM interface
+
+*/
+
+#include <iostream>
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/device/gemm_universal.h"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/host/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "testbed_universal.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
+
+
+TEST(SM80_Device_GemmUniversal_f16t_s4n_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
+
+  using ElementA = cutlass::half_t;
+  using ElementB = cutlass::int4b_t;
+  using ElementOutput = cutlass::half_t;
+  using ElementAccumulator = cutlass::half_t;
+
+  using Gemm = cutlass::gemm::device::GemmUniversal<
+    ElementA,
+    cutlass::layout::RowMajor,
+    ElementB,
+    cutlass::layout::ColumnMajor,
+    ElementOutput,
+    cutlass::layout::RowMajor,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    cutlass::gemm::GemmShape<128, 128, 64>,
+    cutlass::gemm::GemmShape<64, 64, 64>,
+    cutlass::gemm::GemmShape<16, 8, 16>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+      ElementAccumulator, ElementAccumulator>,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    4,   // Stages
+    8,   // AlignmentA
+    32,  // AlignmentB
+    cutlass::arch::OpMultiplyAddMixedInputUpcast,
+    cutlass::ComplexTransform::kNone,
+    cutlass::ComplexTransform::kNone
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
+}
+////////////////////////////////////////////////////////////////////////////////
+
+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////