Larger sized mma instructions to support full vectorization (#1824)
shmsong authored Jul 30, 2022
1 parent 9bb4cf7 commit e0ae11a
Showing 7 changed files with 528 additions and 33 deletions.
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -1031,6 +1031,7 @@ void validateMma(Fusion* fusion) {
validateMinimumArch(7, 0);
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Turing_16_16_16:
validateMinimumArch(7, 5);

// Check that operands come from ldmatrix, can be
@@ -1039,6 +1040,7 @@ void validateMma(Fusion* fusion) {
validateTuringMmaInput(mma->inB()->as<TensorView>());
break;
case MmaOptions::MacroType::Ampere_16_8_16:
case MmaOptions::MacroType::Ampere_16_16_16:
validateMinimumArch(8, 0);

// Check that operands come from ldmatrix, can be
23 changes: 21 additions & 2 deletions torch/csrc/jit/codegen/cuda/mma_type.cpp
@@ -32,6 +32,10 @@ MmaBuilder::MmaBuilder(
case MmaOptions::MacroType::Ampere_16_8_16:
option_.accumulator_stride = outer_stride * 2;
break;
case MmaOptions::MacroType::Ampere_16_16_16:
case MmaOptions::MacroType::Turing_16_16_16:
option_.accumulator_stride = outer_stride * 4;
break;
default:
TORCH_CHECK(false, "unsupported macro");
break;
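
Why the stride doubles here: as the runtime additions in tensorcore.cu further down show, the emulated 16x16x16 op issues two M16N8K16TN calls whose accumulator bases are &_C[0] and &_C[2], and (as the truncated M16N8K16TN body suggests) each call writes _C[0], _C[1], _C[acc_stride], _C[acc_stride + 1] relative to its base. A minimal standalone sketch of the resulting per-thread index pattern, taking outer_stride as 1 for illustration:

// Standalone sketch (not part of the diff): per-thread accumulator indices
// touched by the emulated 16x16x16 mma when acc_stride == outer_stride * 4.
#include <cstdio>
#include <set>

int main() {
  const int acc_stride = 4; // outer_stride * 4, with outer_stride taken as 1
  const int bases[] = {0, 2}; // the two M16N8K16TN halves use &_C[0] and &_C[2]
  const int offs[] = {0, 1, acc_stride, acc_stride + 1};
  std::set<int> touched;
  for (int base : bases) {
    for (int off : offs) {
      touched.insert(base + off); // _C indices written per thread
    }
  }
  for (int i : touched) {
    std::printf("%d ", i); // prints: 0 1 2 3 4 5 6 7
  }
  std::printf("\n");
  return 0;
}

With a stride of only outer_stride * 2 the two halves would collide on indices 2 and 3; scaling by 4 keeps the eight fp32 accumulators per thread distinct.
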
@@ -84,6 +88,8 @@ LoadStoreOpType getLdMatrixType(MmaOptions options) {
switch (options.macro) {
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
case MmaOptions::MacroType::Ampere_16_16_16:
case MmaOptions::MacroType::Turing_16_16_16:
// Turing mma assumes TN as default
transpose = (options.operand == MmaOptions::Operand::A &&
!isOperandTransposed(options)) ||
@@ -109,16 +115,20 @@ bool isVolta(MmaOptions::MacroType macro) {
}

bool isTuring(MmaOptions::MacroType macro) {
return macro == MmaOptions::MacroType::Turing_16_8_16;
return macro == MmaOptions::MacroType::Turing_16_8_16 ||
macro == MmaOptions::MacroType::Turing_16_16_16;
}

bool isAmpere(MmaOptions::MacroType macro) {
return macro == MmaOptions::MacroType::Ampere_16_8_16;
return macro == MmaOptions::MacroType::Ampere_16_8_16 ||
macro == MmaOptions::MacroType::Ampere_16_16_16;
}

int getOutputRegisterSize(MmaOptions::MacroType macro) {
switch (macro) {
case MmaOptions::MacroType::Volta_16_16_4:
case MmaOptions::MacroType::Ampere_16_16_16:
case MmaOptions::MacroType::Turing_16_16_16:
return 8;
break;
case MmaOptions::MacroType::Turing_16_8_16:
@@ -138,7 +148,9 @@ int getInputARegisterSize(MmaOptions::MacroType macro) {
return 4;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_8_16:
case MmaOptions::MacroType::Ampere_16_16_16:
return 8;
break;
default:
@@ -156,6 +168,9 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) {
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
return 4;
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_16_16:
return 8;
default:
TORCH_INTERNAL_ASSERT(false, "unknown macro");
break;
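
These register counts are just warp-level fragment sizes: an R x C fragment spread over the 32 lanes of a warp leaves R * C / 32 elements per lane, so the 16x16x16 macros need 8 registers for the accumulator and for each operand, while the 16x8x16 macros only need 4 for the K16 x N8 B fragment. A standalone arithmetic check, assuming nothing beyond the shapes named by the macros:

// Standalone sketch (not part of the diff): per-lane fragment sizes behind
// the register counts above.
#include <cassert>

// Elements of an R x C fragment owned by each of the 32 lanes of a warp.
constexpr int perLane(int rows, int cols) {
  return rows * cols / 32;
}

int main() {
  // M16N16K16: 16x16 accumulator, 16x16 A, 16x16 B -> 8 registers each.
  assert(perLane(16, 16) == 8);
  // M16N8K16: the K16 x N8 B fragment -> 4 registers, matching the
  // existing "return 4" above.
  assert(perLane(16, 8) == 4);
  return 0;
}
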
@@ -208,6 +223,10 @@ std::string toString(MmaOptions::MacroType mt) {
case MmaOptions::MacroType::Ampere_16_8_16:
ss << "M16N8K16";
break;
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_16_16:
ss << "M16N16K16";
break;
default:
TORCH_INTERNAL_ASSERT(false, "undefined mma type");
break;
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/mma_type.h
@@ -62,7 +62,9 @@ struct MmaOptions {
NoMMA = 0,
Volta_16_16_4,
Ampere_16_8_16,
Ampere_16_16_16,
Turing_16_8_16,
Turing_16_16_16,
Ampere_16_8_8 // place holder for tf32
};

48 changes: 48 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu
@@ -276,6 +276,30 @@ DEVICE_INLINE void M16N8K16TN(
_C[acc_stride + 1] = C_data[3];
}

template <int acc_stride>
DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
float* _C = reinterpret_cast<float*>(accumulator);
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
}

template <int acc_stride = 2>
DEVICE_INLINE void M16N16K16TN(
Array<float, 8, 8>* C,
Array<__half, 8, 8>* A,
Array<__half, 8, 8>* B) {
float* _C = reinterpret_cast<float*>(C);
__half* _B = reinterpret_cast<__half*>(B);
M16N8K16TN<acc_stride>(
reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
A,
reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
M16N8K16TN<acc_stride>(
reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
A,
reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
}

} // namespace Turing

#endif // Arch 75
@@ -338,6 +362,30 @@ DEVICE_INLINE void M16N8K16TN(
_C[acc_stride + 1] = C_data[3];
}

template <int acc_stride>
DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
float* _C = reinterpret_cast<float*>(accumulator);
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
}

template <int acc_stride = 2>
DEVICE_INLINE void M16N16K16TN(
Array<float, 8, 8>* C,
Array<__half, 8, 8>* A,
Array<__half, 8, 8>* B) {
float* _C = reinterpret_cast<float*>(C);
__half* _B = reinterpret_cast<__half*>(B);
M16N8K16TN<acc_stride>(
reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
A,
reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
M16N8K16TN<acc_stride>(
reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
A,
reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
}

} // namespace Ampere

#endif // Arch 80
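
Both the Turing and Ampere namespaces receive the same emulation: M16N16K16TN splits the 16x16 output tile into two 16x8 halves along N and runs M16N8K16TN on each, reusing the full A fragment and handing each call half of B and half of the accumulator (the device code's &_C[2] and &_B[4] offsets are in per-thread register-fragment space rather than column space). A host-side reference of the same N-split, a sketch in plain float arithmetic rather than the nvfuser runtime types:

// Standalone sketch (not part of the diff): 16x16x16 as two 16x8x16 tiles.
#include <cassert>

constexpr int M = 16, N = 16, K = 16;

// Reference 16x8x16 tile: C[M][0..7] += A[M][K] * B[K][0..7], with B and C
// given as pointers into wider row-major 16-column buffers.
void mma16x8x16(float* C, const float* A, const float* B) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < 8; ++n)
      for (int k = 0; k < K; ++k)
        C[m * N + n] += A[m * K + k] * B[k * N + n];
}

// Emulated 16x16x16 tile: one call per N-half, A reused unchanged.
void mma16x16x16(float* C, const float* A, const float* B) {
  mma16x8x16(&C[0], A, &B[0]); // columns 0..7
  mma16x8x16(&C[8], A, &B[8]); // columns 8..15
}

int main() {
  float A[M * K], B[K * N], C[M * N] = {}, Ref[M * N] = {};
  for (int i = 0; i < M * K; ++i) A[i] = float(i % 7) - 3.0f;
  for (int i = 0; i < K * N; ++i) B[i] = float(i % 5) - 2.0f;

  mma16x16x16(C, A, B);

  // Plain reference GEMM over the full 16x16 tile.
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        Ref[m * N + n] += A[m * K + k] * B[k * N + n];

  for (int i = 0; i < M * N; ++i) assert(C[i] == Ref[i]);
  return 0;
}
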
152 changes: 123 additions & 29 deletions torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp
@@ -134,6 +134,13 @@ void WarpMmaSwizzler::scheduleMmaWarpOutput(
setWarpMapped(tv, 4);
}
break;
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_16_16:
scheduleTuringM16N16K16MmaWarpOutput(tv, options);
if (tv->definition()->isA<MmaOp>()) {
setWarpMapped(tv, 4);
}
break;
default:
TORCH_CHECK(
false, "scheduleMmaWarp: unsupported mma option ", toString(macro));
@@ -151,6 +158,8 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) {
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_16_16:
scheduleTuringOperandRead(tv, options);
break;
default:
@@ -505,7 +514,9 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
// Check mma option is supported
TORCH_CHECK(
options.macro == MmaOptions::MacroType::Ampere_16_8_16 ||
options.macro == MmaOptions::MacroType::Turing_16_8_16,
options.macro == MmaOptions::MacroType::Ampere_16_16_16 ||
options.macro == MmaOptions::MacroType::Turing_16_8_16 ||
options.macro == MmaOptions::MacroType::Turing_16_16_16,
"scheduleLdMatrix: unknown macro for ldmatrix");

if (options.operand == MmaOptions::Operand::A) {
@@ -548,43 +559,80 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N);
auto k_dims = getMmaRootDimensions(tv, mma, MmaDimension::K);

// validation:
TORCH_INTERNAL_ASSERT(
canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 8),
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
TORCH_INTERNAL_ASSERT(
canValidateIsInnerDim(k_dims.back(), tv->axis(-1), 16),
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");

if (transposed) {
// [8, 16]
tv->split(-2, 4);

// [2i, 4i, 16]
tv->reorder({{-1, -2}, {-2, -1}});
// [2i, 16, 4i]
// Each ldmatrix.x4 would be loading an effective 16x16x16 tile, which is 2x
//  the size of the regular 16x8x16 tile supported by the largest mma
//  operation. The swizzle also needs to be different to take this into
//  account.
// TODO:
//  Using an emulated 16x16x16 mma tile is a temporary step to enable the
//  widest load possible for the scheduler bring-up phase. A unifying step
//  would be needed in a follow-up to support all these swizzles with a
//  single affine utility.
bool use_ldmatrix4 = canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 16);

if (use_ldmatrix4) {
// [... N16, K16]
tv->split(-2, 8);
tv->split(-1, 8);

// -4 -3 -2 -1
// [... N2o, N8, K2o, K8]
tv->reorder({{-3, -2}, {-2, -3}});
// [... N2o, K2o, N8, K8]

if (transposed) {
tv->reorder({{-1, -2}, {-2, -1}});
}

tv->merge(-4);
tv->merge(-3);
// [warp, 4i]
} else {
//[8, 16]
tv->split(-1, 4);
tv->split(-2, 2);

// 0 1 2 3
//[8, oo2,oi2,i4]
tv->reorder({{-4, -2}, {-2, -4}});

// 0 1 2 3
//[oi2, oo2, 8,i4]
// [Warp, K8]
tv->axis(-2)->parallelize(ParallelType::TIDx);
if (is_immediate_output) {
tv->axis(-1)->parallelize(ParallelType::Vectorize);
}
} else {
// validation:
TORCH_INTERNAL_ASSERT(
canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 8),
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");

if (transposed) {
// [8, 16]
tv->split(-2, 4);

// [2i, 4i, 16]
tv->reorder({{-1, -2}, {-2, -1}});
// [2i, 16, 4i]

tv->merge(-3);
// [warp, 4i]
} else {
//[8, 16]
tv->split(-1, 4);
tv->split(-2, 2);

// 0 1 2 3
//[8, oo2,oi2,i4]
tv->reorder({{-4, -2}, {-2, -4}});

// 0 1 2 3
//[oi2, oo2, 8,i4]

tv->merge(-4);
tv->merge(-3);
// 0 1
//[warp, i4]
}

tv->merge(-4);
tv->merge(-3);
// 0 1
//[warp, i4]
tv->axis(-2)->parallelize(ParallelType::TIDx);
}

tv->axis(-2)->parallelize(ParallelType::TIDx);
} else {
TORCH_INTERNAL_ASSERT(false, "unreachable");
}
@@ -717,6 +765,52 @@ void WarpMmaSwizzler::scheduleTuringM16N8K16MmaWarpOutput(
tv->axis(m_pos)->parallelize(ParallelType::TIDx);
}

void WarpMmaSwizzler::scheduleTuringM16N16K16MmaWarpOutput(
TensorView* tv,
const MmaOptions& options) {
// Assume last 2 dims [M16, N16] or [M16, N16, R]
// Locate instruction m
bool is_reduction = tv->axis(-1)->isReduction();

// Make sure instruction tile size is correct.
if (is_reduction) {
validateMmaRootInnerMNK(tv, options, 16, 16, 16);
} else {
validateMmaRootInnerMN(tv, options, 16, 16);
}

int m_pos = is_reduction ? -3 : -2;
// m
// [16, 16 (,R)]

tv->split(m_pos + 1, 8);
// m
// [16, n2, 8 (,R)]
tv->reorder({{m_pos, m_pos - 1}, {m_pos - 1, m_pos}});

// m
// [n2, 16, 8 (,R)]
tv->split(m_pos, 8);
tv->split(m_pos + 1, 2);

// m
// [2o, 8o, 4i, 2i (,R)]
tv->merge(m_pos - 1);

// m
// [2o, Warp, 2i (,R)]
TORCH_CHECK(tv->definition() != nullptr);

if (is_reduction && tv->definition()->isA<MmaOp>()) {
// Set instruction loops for mma reduce
for (int pos : c10::irange(5)) {
tv->axis(-pos - 1)->parallelize(ParallelType::Mma);
}
}

tv->axis(m_pos)->parallelize(ParallelType::TIDx);
}

namespace {

bool isMmaInitLoop(const kir::Scope& loop_body) {
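
The matching accumulator swizzle in scheduleTuringM16N16K16MmaWarpOutput follows the same pattern: the 16x16 output tile ends up as [2, 2, Warp, 2], one warp on TIDx and eight serial fp32 values per thread, consistent with getOutputRegisterSize returning 8 and the Array<float, 8, 8> accumulator in the runtime. A standalone check of the extents, following the split/reorder/merge comments in the function above:

// Standalone sketch (not part of the diff): extents of the M16N16K16 output swizzle.
#include <cassert>

int main() {
  // [M16, N16] -> split N by 8 -> reorder -> [N2, M16, N8]
  // -> split M by 8, split N8 by 2 -> [N2, M2, M8, N4, N2i]
  // -> merge(M8, N4) -> [N2, M2, Warp, N2i]
  const int N2 = 2, M2 = 2, M8 = 8, N4 = 4, N2i = 2;
  const int warp_extent = M8 * N4; // the merged axis bound to TIDx
  assert(warp_extent == 32);

  const int regs_per_thread = N2 * M2 * N2i; // serial accumulator elements
  assert(regs_per_thread == 8);              // getOutputRegisterSize == 8
  assert(N2 * M2 * warp_extent * N2i == 16 * 16); // whole tile is covered
  return 0;
}
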
18 changes: 16 additions & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h
@@ -115,18 +115,32 @@ class TORCH_CUDA_CU_API WarpMmaSwizzler {
MmaOptions options = MmaOptions());

private:
//! Swizzle implementations for Volta mma.
//! Operand swizzle implementations for Volta mma.
static void scheduleVoltaOperandRead(TensorView* tv, MmaOptions options);

//! Accumulator swizzle implementations for Volta mma.
static void scheduleVoltaM16N16K4Fp32Output(
TensorView* tv,
const MmaOptions& options);

//! Swizzle implementations for Turing mma.
//! Operand swizzle implementations for Turing and Ampere mma.
static void scheduleTuringOperandRead(TensorView* tv, MmaOptions options);

//! Accumulator swizzle implementation for Turing and Ampere mma.
static void scheduleTuringM16N8K16MmaWarpOutput(
TensorView* tv,
const MmaOptions& options);

//! Accumulator swizzle implementation for emulated 16x16x16 mma tile
//! that enables using ldmatrix.x4.
//! Note:
//! Keeping both this option and the ldmatrix.x2 variant above for
//! now for wider scheduler exploration space. Eventually both of
//! these can be unified with a single affine utility.
static void scheduleTuringM16N16K16MmaWarpOutput(
TensorView* tv,
const MmaOptions& options);

//! Utility to lock the transformed dimensions from further transforms.
static void setWarpMapped(TensorView* tv, int number_of_dims);
};