Add int4b_t/uint4b_t support for mixed dtypes GEMM
Aleksandar Samardžić committed Jan 12, 2024
1 parent acba5be commit 3bcd5b0
Showing 4 changed files with 206 additions and 15 deletions.
3 changes: 2 additions & 1 deletion include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
@@ -229,7 +229,8 @@ struct DefaultMmaTensorOp<
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32,
/// or F16 <= F16 x S4 + F16, F16 <= BF16 x S4 + F32)
template <
/// Shape of one matrix production operation (concept: GemmShape)
typename WarpShape_,
168 changes: 158 additions & 10 deletions include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
@@ -264,6 +264,117 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
return result;
}

};
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 4b (S4/U4)
/// for operand B multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
/// Element type for the operand in shared memory for ldmatrix
typename ElementLoad_,
/// Number of mma.sync operations performed along rows or columns
int NumMmaInstructions,
/// Number of elements in warp fragment
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
NumMmaInstructions,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kB,
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 4)>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;

static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand::kB;

using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;

private:
int src_lane_0_, src_lane_1_;
uint32_t byte_selector_0_, byte_selector_10_, byte_selector_11_;
int dst_incr_0_, dst_incr_1_;

public:
CUTLASS_DEVICE
FragmentShuffler() {
int lane_id = cutlass::arch::LaneId();
int mul;

src_lane_0_ = lane_id ^ 1;
mul = lane_id & 1;
byte_selector_0_ = mul * 0x3715 + (1 - mul) * 0x6240;

src_lane_1_ = lane_id ^ 2;
mul = (lane_id & 2) >> 1;
byte_selector_10_ = mul * 0x7632 + (1 - mul) * 0x5410;
byte_selector_11_ = mul * 0x5410 + (1 - mul) * 0x7632;
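// Note on the selectors: __byte_perm(a, b, sel) builds a 32-bit result by
// picking four bytes out of the pair {a, b} (byte indices 0-3 address a,
// 4-7 address b), one result byte per nibble of sel, low nibble first.
// So 0x5410 packs the two low bytes of a and b, 0x7632 the two high
// bytes, and 0x6240 / 0x3715 interleave the even- / odd-indexed bytes of
// the two inputs. The multiplications by `mul` above just pick between
// the two selector constants according to the lane's position within its
// pair or quadruple of lanes.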
dst_incr_0_ = mul * (WarpFragment::kElements / 16);
dst_incr_1_ = (1 - mul) * (WarpFragment::kElements / 16);
}

CUTLASS_DEVICE
WarpFragment operator()(WarpFragment const &src) {

WarpFragment result;

MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const *>(&src);
MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment *>(&result);

uint32_t const* src_ptr = reinterpret_cast<uint32_t const *>(&mma_frag_src_ptr[0]);
uint32_t* dst_ptr = reinterpret_cast<uint32_t *>(&mma_frag_dst_ptr[0]);

// The code assumes that twice as many values as needed for a
// F16/BF16 MMA are loaded along the contiguous dimension. E.g. in
// the case of a column-major matrix: threads 0-3 would hold 32
// elements of the first column in the warp fragment, threads 4-7
// 32 elements of the second column, etc.; but only the first 16
// elements of each column will be used for the first MMA
// operation, and the last 16 elements will be used for the
// follow-up MMA operation. This code distributes the input values
// across threads so that the whole left half (for a row-major
// matrix) or upper half (for a column-major matrix) of the values
// comes first, and the right/lower half comes second in the
// corresponding warp fragments. The values are also re-distributed
// between threads so that each value belongs to the proper thread
// for the F16/BF16 MMA that will take place after the up-casting.
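// Note: each iteration of the loop below moves two 32-bit registers,
// i.e. sixteen 4-bit values, hence the WarpFragment::kElements / 16
// trip count.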

CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < WarpFragment::kElements / 16; n++) {
// Exchange values with the neighboring thread that loaded values
// from the same half of all values as this thread; then combine
// values in such a way that the final values can be produced after
// another exchange.
uint32_t tmp0 = __shfl_sync(0xFFFFFFFF, src_ptr[2 * n], src_lane_0_);
uint32_t tmp1 = __shfl_sync(0xFFFFFFFF, src_ptr[2 * n + 1], src_lane_0_);
tmp0 = __byte_perm(src_ptr[2 * n], tmp0, byte_selector_0_);
tmp1 = __byte_perm(src_ptr[2 * n + 1], tmp1, byte_selector_0_);

// Exchange values with the corresponding thread from the same
// quadruple that loaded values from the other half of all values.
// Then combine values to produce the final values held by this
// thread.
uint32_t mine = __byte_perm(tmp0, tmp1, byte_selector_10_);
uint32_t theirs = __byte_perm(tmp0, tmp1, byte_selector_11_);
theirs = __shfl_sync(0xFFFFFFFF, theirs, src_lane_1_);
dst_ptr[n + dst_incr_0_] = mine;
dst_ptr[n + dst_incr_1_] = theirs;
}

return result;
}

};

////////////////////////////////////////////////////////////////////////////////
@@ -412,10 +523,21 @@ class MmaMixedInputTensorOp {

public:

// Chosen so we get M=16 for int8 and M=32 for int4.
static constexpr int LoadInstructionM =
(sizeof_bits<ElementB>::value > sizeof_bits<ElementA>::value)
? 8 * sizeof_bits<ElementB>::value / sizeof_bits<ElementA>::value
: InstructionShape::kM;

// Shape for loading the data type from shared memory, accounting
// for a possibly narrower ElementA.
using LoadInstructionShapeA =
GemmShape<LoadInstructionM, InstructionShape::kN, InstructionShape::kK>;

/// Iterates over the A operand in Shared Memory
using IteratorA = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
MatrixShape<LoadInstructionShapeA::kM, LoadInstructionShapeA::kK>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

/// Storage for A tile in registers (loaded from Shared Memory)
@@ -428,10 +550,21 @@ class MmaMixedInputTensorOp {
/// Underlying arch::Mma instruction operand fragement for matrix A
using MmaOperandA = typename ArchMmaOperator::FragmentA;

// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionK =
(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value)
? 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value
: InstructionShape::kK;
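// For example, with ElementA = half_t (16 bits) and ElementB = int4b_t
// (4 bits) this gives LoadInstructionK = 8 * 16 / 4 = 32, i.e. twice the
// K extent of a 16x8x16 mma.sync; for int8_t it stays at 8 * 16 / 8 = 16.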

// Shape for loading the data type from shared memory, accounting
// for a possibly narrower ElementB.
using LoadInstructionShapeB =
GemmShape<InstructionShape::kM, InstructionShape::kN, LoadInstructionK>;

/// Iterates over the B operand in Shared Memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
MatrixShape<LoadInstructionShapeB::kK, LoadInstructionShapeB::kN>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

/// Storage for B tile in registers (loaded from Shared Memory)
@@ -492,6 +625,13 @@ class MmaMixedInputTensorOp {
MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);
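// For 4-bit ElementB, a single transformed B fragment covers two
// consecutive mma K-steps; transform_B_flag_ tracks which half of it
// this call should consume, advancing ptr_B past the first half on
// every second call.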

if constexpr (is_B_4bit) {
if (!transform_B_flag_) {
ptr_B += TransformedFragmentB::kElements / 2 / MmaOperandB::kElements;
}
transform_B_flag_ = !transform_B_flag_;
}

CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m) {

@@ -523,14 +663,17 @@ class MmaMixedInputTensorOp {
FragmentA const &A, FragmentB const &B) const {

if (transform_B_flag_) {
// Shuffle data within warp to obtain the mma.sync operand layout
detail::FragmentShuffler<MmaElementB, ElementB, MmaIterations::kColumn,
FragmentB::kElements, MmaOperandB::kElements, Operand::kB> shuffler_B;
FragmentB tmp_B;
tmp_B = shuffler_B(B);

// Convert the B operand to the Mma Instruction operand type
detail::FragmentConverter<MmaElementB, ElementB, FragmentB::kElements> convert_B;
dst_B = convert_B(tmp_B);
}
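// When transform_B_flag_ is false, dst_B is left untouched: it is
// expected to still hold the fragment up-converted by the previous
// call, whose second half will be consumed by the next operator()
// invocation.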

FragmentA tmp_A;

@@ -553,6 +696,11 @@ class MmaMixedInputTensorOp {

ptr_dst_A[1] = convert_A(ptr_tmp_A[1]);
}

private:
static constexpr bool is_B_4bit = cutlass::sizeof_bits<ElementB>::value == 4;
static_assert(!is_B_4bit || FragmentB::kElements % 16 == 0);
mutable bool transform_B_flag_ = true;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
23 changes: 23 additions & 0 deletions test/unit/gemm/warp/gemm_mixed_input_sm80.cu
@@ -324,4 +324,27 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_1
.run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= F16 * I4 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i4, 128x128x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = cutlass::half_t;
using ElementB = cutlass::int4b_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
.run();
}

#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
27 changes: 23 additions & 4 deletions test/unit/gemm/warp/testbed.h
@@ -189,6 +189,7 @@ struct Testbed {
tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false);

}

/// Returns true if the CUDA device is sufficient to execute the kernel.
@@ -669,24 +670,42 @@ __global__ void kernel_transform(

Mma mma;

constexpr size_t sizeof_bits_A =
cutlass::sizeof_bits<typename Mma::ElementA>::value;
constexpr size_t sizeof_bits_B =
cutlass::sizeof_bits<typename Mma::ElementB>::value;
constexpr bool is_mixed_and_B_4bit =
(sizeof_bits_A != sizeof_bits_B) && (sizeof_bits_B == 4);
static_assert(!is_mixed_and_B_4bit || FragmentB::kElements % 8 == 0);

accum.clear();

CUTLASS_PRAGMA_NO_UNROLL
for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled

static constexpr int kIncr =
(is_mixed_and_B_4bit ? 2 : 1) * Mma::Policy::MmaShape::kK;
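// For the mixed-input 4-bit-B case kIncr is doubled: one B fragment
// provides data for two mma K-steps, so each iteration of the loop
// below issues two transform + mma pairs while advancing iter_B once.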
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < ThreadblockShape::kK; k += kIncr) {
iter_A.load(loaded_frag_A);
++iter_A;

iter_B.load(loaded_frag_B);
++iter_B;

mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A,
loaded_frag_B);

mma(accum, transformed_frag_A, transformed_frag_B, accum);

if constexpr (is_mixed_and_B_4bit) {
iter_A.load(loaded_frag_A);
++iter_A;

mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A,
loaded_frag_B);

mma(accum, transformed_frag_A, transformed_frag_B, accum);
}
}
}
