Add int4b_t/uint4b_t support for mixed dtypes GEMM #1190
@@ -266,6 +266,117 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
    return result;
  }

};
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 4b (S4/U4)
/// for operand B multiplicand going through upcasting.
template <
  /// Element type for the operand in registers for the mma.sync
  typename ElementMma_,
  /// Element type for the operand in shared memory for ldmatrix
  typename ElementLoad_,
  /// Number of mma.sync operations performed along rows or columns
  int NumMmaInstructions,
  /// Number of elements in warp fragment
  int NumElementsInWarpFragment,
  /// Number of elements in mma fragment
  int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
                         NumMmaInstructions,
                         NumElementsInWarpFragment,
                         NumElementsInMmaFragment,
                         Operand::kB,
                         typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
                                                      (sizeof_bits<ElementLoad_>::value == 4)>::type> {
public:
  using ElementMma = ElementMma_;
  using ElementLoad = ElementLoad_;

  static int const kNumMmaInstructions = NumMmaInstructions;
  static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
  static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
  static Operand const kOperand = Operand::kB;

  using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
  using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;

private:
  int src_lane_0_, src_lane_1_;
  uint32_t byte_selector_0_, byte_selector_10_, byte_selector_11_;
  int dst_incr_0_, dst_incr_1_;

public:
  CUTLASS_DEVICE
  FragmentShuffler() {
    int lane_id = cutlass::arch::LaneId();
    int mul;

    src_lane_0_ = lane_id ^ 1;
    mul = lane_id & 1;
    byte_selector_0_ = mul * 0x3715 + (1 - mul) * 0x6240;

    src_lane_1_ = lane_id ^ 2;
    mul = (lane_id & 2) >> 1;
    byte_selector_10_ = mul * 0x7632 + (1 - mul) * 0x5410;
    byte_selector_11_ = mul * 0x5410 + (1 - mul) * 0x7632;
    dst_incr_0_ = mul * (WarpFragment::kElements / 16);
    dst_incr_1_ = (1 - mul) * (WarpFragment::kElements / 16);
  }
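  // Note on the magic selectors (added for clarity): __byte_perm(x, y, s)
  // assembles a 32-bit result from the 8-byte pool {x bytes 0-3, y bytes 4-7},
  // with each nibble of s (LSB first) picking one source byte. So 0x5410
  // keeps the low halves of x and y, 0x7632 the high halves, while 0x6240
  // and 0x3715 interleave the even/odd bytes of the two inputs.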

  CUTLASS_DEVICE
  WarpFragment operator()(WarpFragment const &src) {

    WarpFragment result;

    MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const *>(&src);
    MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment *>(&result);

    uint32_t const* src_ptr = reinterpret_cast<uint32_t const *>(&mma_frag_src_ptr[0]);
    uint32_t* dst_ptr = reinterpret_cast<uint32_t *>(&mma_frag_dst_ptr[0]);

    // The code assumes that twice as many values as needed for a
    // F16/BF16 MMA are loaded along the contiguous dimension. E.g., in
    // the case of a column-major matrix: threads 0-3 would hold 32
    // elements of the first column in the warp fragment, threads 4-7
    // 32 elements of the second column, etc.; but only the first 16
    // elements of each column will be used for the first MMA
    // operation, and the last 16 elements will be used for the
    // follow-up MMA operation. This code redistributes the input
    // values across threads so that all of the left (in case of a
    // row-major matrix) or upper (in case of a column-major matrix)
    // half of the values comes first, and the right/lower half of the
    // values comes second in the corresponding warp fragments. The
    // values are also redistributed between threads so that each value
    // belongs to the proper thread for the F16/BF16 MMA that will take
    // place after the up-casting.

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < WarpFragment::kElements / 16; n++) {
      // Exchange values with a neighboring thread that loaded values
      // from the same half of all values as the given thread; then,
      // combine values in such a way that the final values can be
      // produced after another exchange.
      uint32_t tmp0 = __shfl_sync(0xFFFFFFFF, src_ptr[2 * n], src_lane_0_);
      uint32_t tmp1 = __shfl_sync(0xFFFFFFFF, src_ptr[2 * n + 1], src_lane_0_);
      tmp0 = __byte_perm(src_ptr[2 * n], tmp0, byte_selector_0_);
      tmp1 = __byte_perm(src_ptr[2 * n + 1], tmp1, byte_selector_0_);

      // Exchange values with the corresponding thread from the same
      // quadruple as the given thread, but that loaded values from the
      // other half of all values. Then, combine values to produce the
      // final values held by the given thread.
      uint32_t mine = __byte_perm(tmp0, tmp1, byte_selector_10_);
      uint32_t theirs = __byte_perm(tmp0, tmp1, byte_selector_11_);
      theirs = __shfl_sync(0xFFFFFFFF, theirs, src_lane_1_);
      dst_ptr[n + dst_incr_0_] = mine;
      dst_ptr[n + dst_incr_1_] = theirs;
    }

    return result;
  }

};

////////////////////////////////////////////////////////////////////////////////
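
A note on dispatch: the specialization above is selected purely by `sizeof_bits` in the `enable_if` condition. A minimal sketch of the values involved, assuming CUTLASS's standard numeric types from `cutlass/numeric_types.h`:

```cpp
#include "cutlass/numeric_types.h"

// The (16b mma, 4b load, Operand::kB) specialization above is enabled
// exactly for these width combinations:
static_assert(cutlass::sizeof_bits<cutlass::half_t>::value == 16, "F16 mma operand");
static_assert(cutlass::sizeof_bits<cutlass::bfloat16_t>::value == 16, "BF16 mma operand");
static_assert(cutlass::sizeof_bits<cutlass::int4b_t>::value == 4, "S4 loaded operand");
static_assert(cutlass::sizeof_bits<cutlass::uint4b_t>::value == 4, "U4 loaded operand");
```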

@@ -414,10 +525,21 @@ class MmaMixedInputTensorOp {

public:

  // Chosen so we get K=16 for int8 and K=32 for int4.
  static constexpr int LoadInstructionM =
      (sizeof_bits<ElementB>::value > sizeof_bits<ElementA>::value)
          ? 8 * sizeof_bits<ElementB>::value / sizeof_bits<ElementA>::value
          : InstructionShape::kM;

  // Shape for loading the data type from shared memory, accounting
  // when necessary for a narrower ElementA.
  using LoadInstructionShapeA =
      GemmShape<LoadInstructionM, InstructionShape::kN, InstructionShape::kK>;
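  // E.g., with a 16-bit ElementB: a 4-bit ElementA gives
  // 8 * 16 / 4 = 32, and an 8-bit ElementA gives 8 * 16 / 8 = 16
  // (illustrative arithmetic only).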

Review comment: We are only supporting TN layouts with this, right? I only see an adjustment needed for the K-dim (which is going to be the contiguous dimension) and not for the M-dim or N-dim. The K-dim of the iterator, which is on 4 bits, will need to adjust the load shape K?

Reply: I want to eventually support all layout combinations that are supported for int8, so yes, I'll have to refine this part.

  /// Iterates over the A operand in Shared Memory
  using IteratorA = MmaTensorOpMultiplicandTileIterator<
      MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
      MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
      MatrixShape<LoadInstructionShapeA::kM, LoadInstructionShapeA::kK>,
      Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

  /// Storage for A tile in registers (loaded from Shared Memory)

@@ -430,10 +552,21 @@ class MmaMixedInputTensorOp {
  /// Underlying arch::Mma instruction operand fragment for matrix A
  using MmaOperandA = typename ArchMmaOperator::FragmentA;

  // Chosen so we get K=16 for int8 and K=32 for int4.
  static constexpr int LoadInstructionK =
      (sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value)
          ? 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value
          : InstructionShape::kK;

  // Shape for loading the data type from shared memory, accounting
  // when necessary for a narrower ElementB.
  using LoadInstructionShapeB =
      GemmShape<InstructionShape::kM, InstructionShape::kN, LoadInstructionK>;
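  // Same arithmetic as on the A side: with a 16-bit ElementA, a 4-bit
  // ElementB gives LoadInstructionK = 8 * 16 / 4 = 32, and an 8-bit
  // ElementB gives 8 * 16 / 8 = 16, matching the comment above.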

  /// Iterates over the B operand in Shared Memory
  using IteratorB = MmaTensorOpMultiplicandTileIterator<
      MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
      MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
      MatrixShape<LoadInstructionShapeB::kK, LoadInstructionShapeB::kN>,
      Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

  /// Storage for B tile in registers (loaded from Shared Memory)

@@ -494,6 +627,13 @@ class MmaMixedInputTensorOp {
    MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
    MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);

    if constexpr (is_B_4bit) {
      if (!transform_B_flag_) {
        ptr_B += TransformedFragmentB::kElements / 2 / MmaOperandB::kElements;
      }
      transform_B_flag_ = !transform_B_flag_;
    }
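    // (For clarity: a 4-bit FragmentB carries data for two consecutive mma
    // steps. transform() below converts the whole fragment only on every
    // other call, so on the off-step the operand pointer is advanced into
    // the second half of the already-transformed fragment.)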

    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < MmaIterations::kRow; ++m) {

@@ -524,15 +664,17 @@ class MmaMixedInputTensorOp {
   void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
                  FragmentA const &A, FragmentB const &B) const {

-    // Shuffle data within warp to obtain the mma.sync operand layout
-    detail::FragmentShuffler<ElementBMma, ElementB, MmaIterations::kColumn,
-        FragmentB::kElements, MmaOperandB::kElements, Operand::kB> shuffler_B;
-    FragmentB tmp_B;
-    tmp_B = shuffler_B(B);
+    if (!is_B_4bit || transform_B_flag_) {
+      // Shuffle data within warp to obtain the mma.sync operand layout
+      detail::FragmentShuffler<ElementBMma, ElementB, MmaIterations::kColumn,
+          FragmentB::kElements, MmaOperandB::kElements, Operand::kB> shuffler_B;
+      FragmentB tmp_B;
+      tmp_B = shuffler_B(B);

-    // Convert the B operand to the Mma Instruction operand type
-    detail::FragmentConverter<ElementBMma, ElementB, FragmentB::kElements> convert_B;
-    dst_B = convert_B(tmp_B);
+      // Convert the B operand to the Mma Instruction operand type
+      detail::FragmentConverter<ElementBMma, ElementB, FragmentB::kElements> convert_B;
+      dst_B = convert_B(tmp_B);
+    }

     FragmentA tmp_A;

@@ -555,6 +697,11 @@ class MmaMixedInputTensorOp {

    ptr_dst_A[1] = convert_A(ptr_tmp_A[1]);
  }

private:
  static constexpr bool is_B_4bit = cutlass::sizeof_bits<ElementB>::value == 4;
  static_assert(!is_B_4bit || FragmentB::kElements % 16 == 0);
  mutable bool transform_B_flag_ = true;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
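
The flag-driven double consumption of a transformed 4-bit B fragment can be summarized with a minimal host-side sketch; this is purely illustrative (the struct and names below are hypothetical, only the flip-flop logic mirrors the PR):

```cpp
#include <cassert>

// Hypothetical model of transform_B_flag_: the fragment is transformed on
// the "on" step and consumed across two mma() calls, with the operand
// pointer advanced into the second half on the "off" step.
struct TwoStepB {
  bool transform_flag = true;  // mirrors transform_B_flag_

  // Returns the element offset into the transformed fragment for this call.
  int next_offset(int transformed_elements) {
    int offset = transform_flag ? 0 : transformed_elements / 2;
    transform_flag = !transform_flag;
    return offset;
  }
};

int main() {
  TwoStepB b;
  assert(b.next_offset(32) == 0);   // step 1: fragment transformed, first half used
  assert(b.next_offset(32) == 16);  // step 2: no transform, second half used
  assert(b.next_offset(32) == 0);   // pattern repeats every two steps
  return 0;
}
```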

Review comment: @alexsamardzic, for these mainloop changes, can we run full device-level tests to see that nothing is broken on SM80? I would prefer not to touch this file, and instead create a separate version of this file for just the F16 x S4 cases. @hwu36, what are your thoughts on this? Also, is it possible to handle this outside of the mainloop, e.g., in a specialization of the shared memory iterator? We have probably discussed it before, but it is worth re-visiting that thought. Happy to schedule something between the three of us to brainstorm this PR further.