Add int4b_t/uint4b_t support for mixed dtypes GEMM #1190

Open · wants to merge 1 commit into base: main
36 changes: 25 additions & 11 deletions include/cutlass/gemm/threadblock/mma_multistage.h
@@ -181,6 +181,15 @@ class MmaMultistage :
/// Pair of B fragments used to overlap shared memory loads and math instructions
WarpLoadedFragmentB warp_loaded_frag_B_[2];
WarpTransformedFragmentB warp_transformed_frag_B_[2];

using ElementA = typename WarpLoadedFragmentA::Element;
using ElementB = typename WarpLoadedFragmentB::Element;
static constexpr size_t sizeof_bits_A =
cutlass::sizeof_bits<ElementA>::value;
static constexpr size_t sizeof_bits_B =
cutlass::sizeof_bits<ElementB>::value;
static constexpr bool is_mixed_and_B_4bit =
(sizeof_bits_A != sizeof_bits_B) && (sizeof_bits_B == 4);
};
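As a quick sanity check on the new flag, here is a minimal sketch (not part of the diff) that evaluates the defining expression of `is_mixed_and_B_4bit` for the F16 x S4 pairing this PR targets, using the same `cutlass::sizeof_bits` trait as above.

```cpp
#include <cstdint>
#include <cutlass/numeric_types.h>

// Evaluate the flag's defining expression for F16 (A) x S4 (B).
constexpr int bits_A = cutlass::sizeof_bits<cutlass::half_t>::value;   // 16
constexpr int bits_B = cutlass::sizeof_bits<cutlass::int4b_t>::value;  // 4
static_assert((bits_A != bits_B) && (bits_B == 4), "F16 x S4 takes the new path");

// F16 x S8 does not: the widths differ, but B is not 4-bit.
constexpr int bits_s8 = cutlass::sizeof_bits<int8_t>::value;           // 8
static_assert(!((bits_A != bits_s8) && (bits_s8 == 4)), "F16 x S8 keeps the old path");
```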


@@ -254,7 +263,7 @@ class MmaMultistage :
if (smem_read_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations / (PipeState::is_mixed_and_B_4bit ? 2 : 1), 0});
smem_read_stage_idx_ = 0;
}
}
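To make the halved wrap-around concrete, here is a small host-side sketch of the tile offsets computed above when B is 4-bit. It is illustrative only; the values of kStages, kPartitionsK and kWarpGemmIterations are assumptions, not taken from the PR.

```cpp
#include <cstdio>

int main() {
  // Assumed example values for illustration.
  int const kStages = 3, kPartitionsK = 1, kWarpGemmIterations = 8;
  bool const is_mixed_and_B_4bit = true;  // e.g. F16 x S4

  // A advances one tile per k-group, but the 4-bit B iterator advances only
  // once per two k-groups, so B wraps back by half as many tiles.
  int wrap_A = -kStages * kPartitionsK * kWarpGemmIterations;
  int wrap_B = -kStages * kPartitionsK * kWarpGemmIterations /
               (is_mixed_and_B_4bit ? 2 : 1);
  std::printf("wrap offsets: A = %d, B = %d\n", wrap_A, wrap_B);  // -24, -12
  return 0;
}
```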
@@ -510,25 +519,31 @@ class MmaMultistage :
++this->warp_tile_iterator_A_;

// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_B_;
if constexpr (!PipeState::is_mixed_and_B_4bit) {
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_B_;
} else if ((warp_mma_k + 1) % 2 == 0) {
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k / 2 + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k / 2 + 1) % 2]);
++this->warp_tile_iterator_B_;
}

// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
if (warp_mma_k > 0) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2],
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
pipe_state.warp_loaded_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2]);
}

// Execute the current warp-tile of MMA operations
if (Detail::kStagedAccumulation) {
warp_mma_(
pipe_state.tmp_accum_,
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2],
pipe_state.tmp_accum_
);

@@ -541,7 +556,7 @@ class MmaMultistage :
warp_mma_(
accum,
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2],
accum
);
}
@@ -596,12 +611,11 @@ class MmaMultistage :
// the first warp-tile of the next iteration, if necessary (so we can
// immediately start issuing MMA instructions at the top of the loop )
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {

warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2 + 1) % 2 : (warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
pipe_state.warp_loaded_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2 + 1) % 2 : (warp_mma_k + 1) % 2]);
Contributor:
@alexsamardzic, for these mainloop changes, can we run full device-level tests to see that nothing is broken on SM80? I would prefer not to touch this file, and instead to create a separate version of this file just for the F16 x S4 cases.

@hwu36, what are your thoughts on this? Also, is it possible to handle this outside of the mainloop, e.g. in a specialization of the shared memory iterator? We have probably discussed it before, but it is worth revisiting that thought. Happy to schedule something between the three of us to brainstorm this PR further.

}

}
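A minimal host-side model of the scheduling above (an illustrative sketch with an assumed kWarpGemmIterations, not code from the PR): with 4-bit B, one shared-memory load covers two k-groups, so the mainloop issues a B load only on every other warp_mma_k and indexes the double-buffered fragments with warp_mma_k / 2.

```cpp
#include <cstdio>

int main() {
  int const kWarpGemmIterations = 8;  // assumed value for illustration

  for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
    // Mirror of the conditions in the 4-bit branch of the mainloop above.
    bool issue_B_load = ((warp_mma_k + 1) % 2 == 0);
    int frag_consumed = (warp_mma_k / 2) % 2;      // fragment fed to the MMA
    if (issue_B_load) {
      int frag_loaded = (warp_mma_k / 2 + 1) % 2;  // fragment filled by the load
      std::printf("k=%d: load B into fragment %d, MMA uses fragment %d\n",
                  warp_mma_k, frag_loaded, frag_consumed);
    } else {
      std::printf("k=%d: no B load, MMA uses fragment %d\n",
                  warp_mma_k, frag_consumed);
    }
  }
  return 0;
}
```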
3 changes: 2 additions & 1 deletion include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
@@ -229,7 +229,8 @@ struct DefaultMmaTensorOp<
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32)
/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32,
/// or F16 <= F16 x S4 + F16, F16 <= BF16 x S4 + F32)
template <
/// Shape of one matrix production operation (concept: GemmShape)
typename WarpShape_,
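For orientation, a hedged usage sketch of the specialization documented above, instantiated for the new F16 x S4 case. The template argument order and the `OpMultiplyAddMixedInputUpcast` operator tag follow the existing S8 mixed-input path; treat the exact shapes and parameters as assumptions rather than the PR's own test code.

```cpp
#include <cutlass/arch/mma.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/warp/default_mma_tensor_op.h>
#include <cutlass/gemm/warp/default_mma_tensor_op_sm80.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>

// Warp-level MMA that up-casts the S4 B operand to F16 internally (sketch).
using WarpMmaF16xS4 = typename cutlass::gemm::warp::DefaultMmaTensorOp<
    cutlass::gemm::GemmShape<64, 64, 64>,            // warp tile (assumed)
    cutlass::gemm::GemmShape<16, 8, 16>,             // mma.sync shape for F16
    cutlass::half_t,  cutlass::layout::RowMajor,     // A: F16, row-major
    cutlass::int4b_t, cutlass::layout::ColumnMajor,  // B: S4, column-major (TN)
    cutlass::half_t,  cutlass::layout::RowMajor,     // C
    cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
```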
167 changes: 157 additions & 10 deletions include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
@@ -266,6 +266,117 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
return result;
}

};
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 4b (S4/U4)
/// for operand B multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
/// Element type for the operand in shared memory for ldmatrix
typename ElementLoad_,
/// Number of mma.sync operations performed along rows or columns
int NumMmaInstructions,
/// Number of elements in warp fragment
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
NumMmaInstructions,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kB,
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 4)>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;

static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand::kB;

using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;

private:
int src_lane_0_, src_lane_1_;
uint32_t byte_selector_0_, byte_selector_10_, byte_selector_11_;
int dst_incr_0_, dst_incr_1_;

public:
CUTLASS_DEVICE
FragmentShuffler() {
int lane_id = cutlass::arch::LaneId();
int mul;

src_lane_0_ = lane_id ^ 1;
mul = lane_id & 1;
byte_selector_0_ = mul * 0x3715 + (1 - mul) * 0x6240;

src_lane_1_ = lane_id ^ 2;
mul = (lane_id & 2) >> 1;
byte_selector_10_ = mul * 0x7632 + (1 - mul) * 0x5410;
byte_selector_11_ = mul * 0x5410 + (1 - mul) * 0x7632;
dst_incr_0_ = mul * (WarpFragment::kElements / 16);
dst_incr_1_ = (1 - mul) * (WarpFragment::kElements / 16);
}

CUTLASS_DEVICE
WarpFragment operator()(WarpFragment const &src) {

WarpFragment result;

MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const *>(&src);
MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment *>(&result);

uint32_t const* src_ptr = reinterpret_cast<uint32_t const *>(&mma_frag_src_ptr[0]);
uint32_t* dst_ptr = reinterpret_cast<uint32_t *>(&mma_frag_dst_ptr[0]);

// The code assumes that twice as many values as needed for an
// F16/BF16 MMA are loaded along the contiguous dimension. E.g. in
// the case of a column-major matrix: threads 0-3 would hold 32
// elements of the first column in the warp fragment, threads 4-7
// 32 elements of the second column, etc.; but only the first 16
// elements of each column will be used for the first MMA
// operation, and the last 16 elements will be used for the
// follow-up MMA operation. This code redistributes the input
// values across threads so that the left (in case of a row-major
// matrix) or upper (in case of a column-major matrix) half of the
// values comes first, and the right/lower half comes second, in
// the corresponding warp fragments. The values are also
// redistributed between threads so that each value belongs to the
// proper thread for the F16/BF16 MMA that will take place after
// the up-casting.

CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < WarpFragment::kElements / 16; n++) {
// Exchange values with a neighboring thread that loaded values
// from the same half of all values as the given thread; then,
// combine values in such a way that the final values can be
// produced after another exchange.
uint32_t tmp0 = __shfl_sync(0xFFFFFFFF, src_ptr[2 * n], src_lane_0_);
uint32_t tmp1 = __shfl_sync(0xFFFFFFFF, src_ptr[2 * n + 1], src_lane_0_);
tmp0 = __byte_perm(src_ptr[2 * n], tmp0, byte_selector_0_);
tmp1 = __byte_perm(src_ptr[2 * n + 1], tmp1, byte_selector_0_);

// Exchange values with the corresponding thread from the same
// quadruple as the given thread, but which loaded values from the
// other half of all values. Then, combine values to produce the
// final values held by the given thread.
uint32_t mine = __byte_perm(tmp0, tmp1, byte_selector_10_);
uint32_t theirs = __byte_perm(tmp0, tmp1, byte_selector_11_);
theirs = __shfl_sync(0xFFFFFFFF, theirs, src_lane_1_);
dst_ptr[n + dst_incr_0_] = mine;
dst_ptr[n + dst_incr_1_] = theirs;
}

return result;
}

};
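Because the shuffle above leans heavily on `__byte_perm`, here is a small host-side sketch (illustrative only, not part of the PR) that emulates the intrinsic and shows what the two selectors used per pair of registers do: `0x5410` keeps the lower 16-bit halves of both inputs, `0x7632` the upper halves.

```cpp
#include <cstdint>
#include <cstdio>

// Host emulation of CUDA's __byte_perm(a, b, s): each of the four low nibbles
// of s picks one byte out of the 8-byte pool {a bytes 0..3, b bytes 4..7}.
static uint32_t byte_perm(uint32_t a, uint32_t b, uint32_t s) {
  uint64_t pool = (static_cast<uint64_t>(b) << 32) | a;
  uint32_t result = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t idx = (s >> (4 * i)) & 0x7;
    result |= ((pool >> (8 * idx)) & 0xFF) << (8 * i);
  }
  return result;
}

int main() {
  uint32_t lo = 0x33221100;  // bytes 0x00 0x11 0x22 0x33
  uint32_t hi = 0x77665544;  // bytes 0x44 0x55 0x66 0x77
  std::printf("%08x\n", byte_perm(lo, hi, 0x5410));  // 55441100: lower halves
  std::printf("%08x\n", byte_perm(lo, hi, 0x7632));  // 77663322: upper halves
  return 0;
}
```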

////////////////////////////////////////////////////////////////////////////////
@@ -414,10 +525,21 @@ class MmaMixedInputTensorOp {

public:

// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionM =
(sizeof_bits<ElementB>::value > sizeof_bits<ElementA>::value)
? 8 * sizeof_bits<ElementB>::value / sizeof_bits<ElementA>::value
: InstructionShape::kM;

// Shape for loading the data type from shared memory, accounting
// for a possibly narrower ElementA.
using LoadInstructionShapeA =
GemmShape<LoadInstructionM, InstructionShape::kN, InstructionShape::kK>;

Contributor:
We are only supporting TN layouts with this, right? I only see an adjustment needed for the K-dim (which is going to be the contiguous dimension) and not for the M-dim or N-dim. Will the K-dim of the iterator, which is on 4 bits, need to adjust the load shape K?

Contributor (author):
I want to eventually support all layout combinations that are supported for int8, so yes, I'll have to refine this part.
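The arithmetic behind the adjusted load shapes can be checked directly; the sketch below (not part of the diff) spells out the `8 * wider_bits / narrower_bits` expression for the two B types mentioned in the comment above, matching K=16 for S8 and K=32 for S4.

```cpp
#include <cstdint>
#include <cutlass/numeric_types.h>

// F16 x S4: 8 * 16 / 4 = 32 -> each shared-memory load spans K=32.
static_assert(8 * cutlass::sizeof_bits<cutlass::half_t>::value /
                  cutlass::sizeof_bits<cutlass::int4b_t>::value == 32, "K for S4");
// F16 x S8: 8 * 16 / 8 = 16 -> the regular mma.sync K extent is kept.
static_assert(8 * cutlass::sizeof_bits<cutlass::half_t>::value /
                  cutlass::sizeof_bits<int8_t>::value == 16, "K for S8");
```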

/// Iterates over the A operand in Shared Memory
using IteratorA = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
MatrixShape<LoadInstructionShapeA::kM, LoadInstructionShapeA::kK>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

/// Storage for A tile in registers (loaded from Shared Memory)
@@ -430,10 +552,21 @@ class MmaMixedInputTensorOp {
/// Underlying arch::Mma instruction operand fragment for matrix A
using MmaOperandA = typename ArchMmaOperator::FragmentA;

// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionK =
(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value)
? 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value
: InstructionShape::kK;

// Shape for loading the data type from shared memory, accounting
// for a possibly narrower ElementB.
using LoadInstructionShapeB =
GemmShape<InstructionShape::kM, InstructionShape::kN, LoadInstructionK>;

/// Iterates over the B operand in Shared Memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
MatrixShape<LoadInstructionShapeB::kK, LoadInstructionShapeB::kN>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

/// Storage for B tile in registers (loaded from Shared Memory)
@@ -494,6 +627,13 @@ class MmaMixedInputTensorOp {
MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);

if constexpr (is_B_4bit) {
if (!transform_B_flag_) {
ptr_B += TransformedFragmentB::kElements / 2 / MmaOperandB::kElements;
}
transform_B_flag_ = !transform_B_flag_;
}

CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m) {

@@ -524,15 +664,17 @@ class MmaMixedInputTensorOp {
void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
FragmentA const &A, FragmentB const &B) const {

// Shuffle data within warp to obtain the mma.sync operand layout
detail::FragmentShuffler<ElementBMma, ElementB, MmaIterations::kColumn,
FragmentB::kElements, MmaOperandB::kElements, Operand::kB> shuffler_B;
FragmentB tmp_B;
tmp_B = shuffler_B(B);
if (!is_B_4bit || transform_B_flag_) {
// Shuffle data within warp to obtain the mma.sync operand layout
detail::FragmentShuffler<ElementBMma, ElementB, MmaIterations::kColumn,
FragmentB::kElements, MmaOperandB::kElements, Operand::kB> shuffler_B;
FragmentB tmp_B;
tmp_B = shuffler_B(B);

// Convert the B operand to the Mma Instruction operand type
detail::FragmentConverter<ElementBMma, ElementB, FragmentB::kElements> convert_B;
dst_B = convert_B(tmp_B);
// Convert the B operand to the Mma Instruction operand type
detail::FragmentConverter<ElementBMma, ElementB, FragmentB::kElements> convert_B;
dst_B = convert_B(tmp_B);
}

FragmentA tmp_A;

@@ -555,6 +697,11 @@ class MmaMixedInputTensorOp {

ptr_dst_A[1] = convert_A(ptr_tmp_A[1]);
}

private:
static constexpr bool is_B_4bit = cutlass::sizeof_bits<ElementB>::value == 4;
static_assert(!is_B_4bit || FragmentB::kElements % 16 == 0);
mutable bool transform_B_flag_ = true;
};
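To summarize how the new `transform_B_flag_` sequences the 4-bit path, here is a minimal host-side model. It is an illustrative sketch under the assumption that transform() and operator() are invoked alternately for the same k-group pair, as in the mainloop changes above: the shuffle-and-convert runs only on every other transform() call, and the following MMA consumes either the first or the second half of the already up-converted fragment.

```cpp
#include <cstdio>

int main() {
  bool transform_B_flag = true;  // mirrors the initial value of transform_B_flag_

  for (int step = 0; step < 4; ++step) {
    // transform(): shuffle + upcast a full (double-width) B fragment only when set.
    bool converts = transform_B_flag;
    // operator(): offset ptr_B to the second half when the flag is cleared,
    // then toggle the flag for the next pair of calls.
    int half_consumed = transform_B_flag ? 0 : 1;
    std::printf("step %d: transform %s, MMA consumes half %d\n",
                step, converts ? "converts B" : "skips B", half_consumed);
    transform_B_flag = !transform_B_flag;
  }
  return 0;
}
```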

/////////////////////////////////////////////////////////////////////////////////////////////////
2 changes: 2 additions & 0 deletions test/unit/gemm/device/CMakeLists.txt
@@ -284,6 +284,8 @@ cutlass_test_unit_add_executable(

gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu
gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu

gemm_universal_f16t_s4n_f16t_mixed_input_tensor_op_f16_sm80.cu
)

cutlass_test_unit_add_executable(