Add support for mixed 4-bit/8-bit data types GEMM (NVIDIA#1413)
* Add support for mixed 4-bit/8-bit data types GEMM

* fix ( and )

---------

Co-authored-by: Aleksandar Samardžić <asamardzic@matf.bg.ac.rs>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
3 people authored and Jiayu Sun committed Sep 4, 2024
1 parent ace7710 commit fcc04df
Showing 15 changed files with 960 additions and 14 deletions.
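At the device level, the new defaults make it possible to instantiate an S4 x S8 tensor-core GEMM without spelling out tile shapes, stages, or an epilogue. A minimal sketch of such an instantiation (the alias name, layouts, and the use of cutlass::gemm::device::Gemm are illustrative assumptions, not code from this commit's tests):

#include "cutlass/gemm/device/gemm.h"

// Hypothetical instantiation; tile shapes, stage count, and epilogue are
// pulled from the DefaultGemmConfiguration specializations added below.
using GemmS4xS8 = cutlass::gemm::device::Gemm<
    cutlass::int4b_t, cutlass::layout::RowMajor,     // ElementA, LayoutA
    int8_t,           cutlass::layout::ColumnMajor,  // ElementB, LayoutB
    int32_t,          cutlass::layout::RowMajor,     // ElementC, LayoutC
    int32_t,                                         // ElementAccumulator
    cutlass::arch::OpClassTensorOp,                  // use Tensor Cores
    cutlass::arch::Sm80>;                            // SM80 mma.sync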
54 changes: 54 additions & 0 deletions include/cutlass/gemm/device/default_gemm_configuration.h
@@ -793,6 +793,60 @@ struct DefaultGemmConfigurationSm89F8 {
using Operator = arch::OpMultiplyAdd;
};

////////////////////////////////////////////////////////////////////////////////

template <
typename ElementC>
struct DefaultGemmConfiguration<
arch::OpClassTensorOp,
arch::Sm80,
int4b_t,
int8_t,
ElementC,
int32_t> {

static int const kAlignmentA = 128 / sizeof_bits<int4b_t>::value;
static int const kAlignmentB = 128 / sizeof_bits<int8_t>::value;

using ThreadblockShape = GemmShape<128, 256, 64>;
using WarpShape = GemmShape<64, 64, 64>;
using InstructionShape = GemmShape<16, 8, 32>;
static int const kStages = 3;

using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

using Operator = arch::OpMultiplyAddSaturate;
};

////////////////////////////////////////////////////////////////////////////////

template <
typename ElementC>
struct DefaultGemmConfiguration<
arch::OpClassTensorOp,
arch::Sm80,
int8_t,
int4b_t,
ElementC,
int32_t> {

static int const kAlignmentA = 128 / sizeof_bits<int8_t>::value;
static int const kAlignmentB = 128 / sizeof_bits<int4b_t>::value;

using ThreadblockShape = GemmShape<128, 256, 64>;
using WarpShape = GemmShape<64, 64, 64>;
using InstructionShape = GemmShape<16, 8, 32>;
static int const kStages = 3;

using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

using Operator = arch::OpMultiplyAddSaturate;
};
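Both specializations express alignment in elements per 128-bit vectorized access; a quick compile-time check of the arithmetic (illustrative, not part of the diff):

static_assert(128 / cutlass::sizeof_bits<cutlass::int4b_t>::value == 32,
              "4-bit operand: 32 elements per 128-bit access");
static_assert(128 / cutlass::sizeof_bits<int8_t>::value == 16,
              "8-bit operand: 16 elements per 128-bit access");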

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for SM89 fe4m3 x fe4m3
template <typename ElementC, typename ElementAccumulator>
struct DefaultGemmConfiguration<
71 changes: 70 additions & 1 deletion include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
@@ -268,7 +268,7 @@ struct DefaultMmaTensorOp<
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");

// Data type used for internal computation - use the wider of the two data types for mma.sync operands
-  using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
+  using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
      ElementA, ElementB>::type;

// Operand datatypes in the internal MMA instruction - use the wider of the two data types
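The sizeof-to-sizeof_bits fix above is what makes this selection work for sub-byte types: int4b_t occupies a full byte of storage, so sizeof() ties with int8_t and the wider-type choice would be arbitrary. An illustrative check (assuming CUTLASS's byte-backed integer_subbyte storage):

static_assert(sizeof(cutlass::int4b_t) == sizeof(int8_t),
              "equal storage sizes; sizeof() cannot rank them");
static_assert(cutlass::sizeof_bits<cutlass::int4b_t>::value == 4 &&
              cutlass::sizeof_bits<int8_t>::value == 8,
              "sizeof_bits<> exposes the 4-bit vs 8-bit difference");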
@@ -294,6 +294,75 @@
Policy, PartitionsK, AccumulatorsInRowMajor>;
};


/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
/// (e.g. S32 <= S4 x S8 + S32, S32 <= S8 x S4 + S32)
template <
/// Shape of one matrix product operation (concept: GemmShape)
typename WarpShape_,
/// Element type of A matrix
typename ElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Element type of B matrix
typename ElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Number of partitions along K dimension
int PartitionsK,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<
WarpShape_,
GemmShape<16, 8, 32>, // InstructionShape
ElementA, // Element type of A matrix in Global Memory
LayoutA, // Layout of A matrix in Global Memory
ElementB, // Element type of B matrix in Global Memory
LayoutB, // Layout of B matrix in Global Memory
ElementC, // Element type of C matrix in Global Memory
LayoutC, // Layout of C matrix in Global Memory
arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
PartitionsK, AccumulatorsInRowMajor> {


// Check if the ElementA and ElementB are of different data types
static_assert(!platform::is_same<ElementA, ElementB>::value,
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");

// Data type used for internal computation - use the wider of the two data types for mma.sync operands
using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
ElementA, ElementB>::type;

// Operand datatypes in the internal MMA instruction - use the wider of the two data types
using MmaElementA = ElementOperand;
using MmaElementB = ElementOperand;
using MmaElementC = ElementC;

// Policy for the underlying 16x8x32 saturating integer mma.sync instruction
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<
GemmShape<16, 8, 32>,
32,
MmaElementA, cutlass::layout::RowMajor,
MmaElementB, cutlass::layout::ColumnMajor,
MmaElementC, cutlass::layout::RowMajor,
arch::OpMultiplyAddSaturate
>,
cutlass::MatrixShape<1, 1> >;

// Define the warp-level tensor op
using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
Policy, PartitionsK, AccumulatorsInRowMajor>;
};
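A distilled view of the operand selection this specialization performs (a sketch, not code from the commit): the wider of the two input types becomes both mma.sync operands, so for S4 x S8 the int4b_t fragment is upcast to int8_t in registers.

using ElementA = cutlass::int4b_t;
using ElementB = int8_t;
using ElementOperand = typename cutlass::platform::conditional<
    (cutlass::sizeof_bits<ElementA>::value > cutlass::sizeof_bits<ElementB>::value),
    ElementA, ElementB>::type;
static_assert(cutlass::platform::is_same<ElementOperand, int8_t>::value,
              "S4 x S8 issues mma.sync with int8_t operands");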

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
14 changes: 10 additions & 4 deletions include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
@@ -104,6 +104,7 @@ struct FragmentShuffler {
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
/// for operand A multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
@@ -122,8 +123,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kA,
-    typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
-                                 (sizeof_bits<ElementLoad_>::value == 8)>::type> {
+    typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
+                                  (sizeof_bits<ElementLoad_>::value == 8)) ||
+                                 ((sizeof_bits<ElementMma_>::value == 8) &&
+                                  (sizeof_bits<ElementLoad_>::value == 4))>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;
@@ -187,6 +190,7 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
/// for operand B multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
@@ -205,8 +209,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kB,
-    typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
-                                 (sizeof_bits<ElementLoad_>::value == 8)>::type> {
+    typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
+                                  (sizeof_bits<ElementLoad_>::value == 8)) ||
+                                 ((sizeof_bits<ElementMma_>::value == 8) &&
+                                  (sizeof_bits<ElementLoad_>::value == 4))>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;
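Both shuffler specializations now gate on the same widened predicate. A compile-time sketch of which (mma, load) element pairs engage the shuffling path (the helper name is hypothetical, not from the diff):

template <typename ElementMma, typename ElementLoad>
constexpr bool kShufflerEngaged =
    ((cutlass::sizeof_bits<ElementMma>::value == 16) &&
     (cutlass::sizeof_bits<ElementLoad>::value == 8)) ||
    ((cutlass::sizeof_bits<ElementMma>::value == 8) &&
     (cutlass::sizeof_bits<ElementLoad>::value == 4));

static_assert(kShufflerEngaged<cutlass::half_t, int8_t>,
              "16b mma fed from 8b loads (pre-existing path)");
static_assert(kShufflerEngaged<int8_t, cutlass::int4b_t>,
              "8b mma fed from 4b loads (added by this commit)");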
80 changes: 80 additions & 0 deletions include/cutlass/numeric_conversion.h
@@ -2771,6 +2771,86 @@ struct NumericArrayConverter<uint4b_t, int, N, Round> {
}
};

/// Partial specialization for Array<int8_t, 8> <= Array<int4b_t, 8>
template <
FloatRoundStyle Round
>
struct NumericArrayConverter<int8_t, int4b_t, 8, Round> {

using result_type = Array<int8_t, 8>;
using source_type = Array<int4b_t, 8>;
static FloatRoundStyle const round_style = Round;

CUTLASS_HOST_DEVICE
static result_type convert(source_type const & source) {

unsigned const& storage = reinterpret_cast<unsigned const &>(source);
unsigned out[2];

asm volatile(
"{ .reg .u32 tmp0, tmp1, tmp2;"
"shl.b32 tmp0, %2, 4;"
"and.b32 tmp0, tmp0, 0xf0f0f0f0;"
"prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
"and.b32 tmp1, tmp1, 0xf0f0f0f0;"
"shr.u32 tmp0, tmp0, 4;"
"or.b32 tmp2, tmp0, tmp1;"
"and.b32 tmp0, %2, 0xf0f0f0f0;"
"prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
"and.b32 tmp1, tmp1, 0xf0f0f0f0;"
"shr.u32 tmp0, tmp0, 4;"
"or.b32 tmp0, tmp0, tmp1;"
"prmt.b32 %0, tmp2, tmp0, 0x5140;"
"prmt.b32 %1, tmp2, tmp0, 0x7362;"
"}"
: "=r"(out[0]), "=r"(out[1])
: "r"(storage));

return reinterpret_cast<result_type const &>(out);
}

CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
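For readers not fluent in PTX: the assembly above writes, for each of the eight packed 4-bit lanes, its sign-extended value into the output byte with the same index (lane 0 is the least significant nibble). A host-side scalar reference of the intended result (an illustrative sketch, not part of the commit):

#include <cstdint>

inline void s4_to_s8_reference(uint32_t packed, int8_t out[8]) {
  for (int i = 0; i < 8; ++i) {
    int v = (packed >> (4 * i)) & 0xF;                  // extract lane i
    out[i] = static_cast<int8_t>(v >= 8 ? v - 16 : v);  // sign-extend 4 -> 8 bits
  }
}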

/// Partial specialization for Array<int8_t> <= Array<int4b_t>
template <
int N,
FloatRoundStyle Round
>
struct NumericArrayConverter<int8_t, int4b_t, N, Round> {
static_assert(!(N % 8), "N must be multiple of 8.");

using result_type = Array<int8_t, N>;
using source_type = Array<int4b_t, N>;
static FloatRoundStyle const round_style = Round;

CUTLASS_HOST_DEVICE
static result_type convert(source_type const & source) {

NumericArrayConverter<int8_t, int4b_t, 8, Round> convert_vector_;

result_type result;

Array<int8_t, 8> *result_ptr = reinterpret_cast<Array<int8_t, 8> *>(&result);
Array<int4b_t, 8> const *source_ptr = reinterpret_cast<Array<int4b_t, 8> const *>(&source);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 8; ++i) {
result_ptr[i] = convert_vector_(source_ptr[i]);
}

return result;
}

CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
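Usage follows the standard converter pattern; a brief illustrative sketch for a 32-element fragment:

cutlass::NumericArrayConverter<int8_t, cutlass::int4b_t, 32> converter;
cutlass::Array<cutlass::int4b_t, 32> src;          // e.g., a packed S4 register fragment
cutlass::Array<int8_t, 32> dst = converter(src);   // performs four 8-lane conversions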

#endif // Conditional guards to enable partial specialization for packed integers

namespace detail {