Skip to content

Commit

Permalink
MMA Rfactor support for cross-warp and cross-CTA split on K dimension (
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsong committed Jul 2, 2022
1 parent 76b3cca commit f008140
Show file tree
Hide file tree
Showing 4 changed files with 634 additions and 42 deletions.
34 changes: 30 additions & 4 deletions torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,26 @@ namespace {
// to the given mma dimension. See [MMA dimension matching].
std::vector<IterDomain*> getMmaDomains(MmaOp* mma, MmaDimension dimension) {
// This utility is user facing so shouldn't ever see tensor index here.
auto accumulator_domain =
mma->out()->as<TensorView>()->getMaybeRFactorDomain();

// Note: [Use Root Domain in Accumulator TV]
// Have to use root domain for accumulator tv since the operands do not have
// root/rfactor domains that map to the rfactor domain of output.
// For example:
// C[I,I,R,R] = mma (A[I,B,I,I], B[B,I,I,I]),
// if we do
// c->split(-1,4);
// c->rfactor(-1);
// on the mma stage we get:
// C[I,I,R,Io,R(4)] = mma (A[I,B,I,I], B[B,I,I,I]),
// and in this case Io and R(4) would not be able to find root mapping
// in A or B.
//
// Essentially in the case of rfactor, this utility does producer side
// matching so looking at root domain would be required.
// This matching pattern should support most common matmul applications,
// but in follow ups we may need to extend RFactor matching if there
// are more complex scheduling patterns that we want to support.
auto accumulator_domain = mma->out()->as<TensorView>()->getRootDomain();
auto a_domain = TensorDomain::noReductions(
mma->inA()->as<TensorView>()->getMaybeRFactorDomain());
auto b_domain = TensorDomain::noReductions(
Expand Down Expand Up @@ -269,10 +287,17 @@ std::vector<IterDomain*> getMmaRootDimensions(

std::vector<IterDomain*> result;

// Need to use root domain for accumulator tv and maybe rfactor domain
// otherwise. See [Use Root Domain in Accumulator TV].
auto is_mma_output =
tv->definition() != nullptr && tv->definition()->isA<MmaOp>();
const auto& tv_root_domain =
is_mma_output ? tv->getRootDomain() : tv->getMaybeRFactorDomain();

// Loop through tensorview's root domains and accumulate all the
// root domain IterDomain's that maps to any of the collected
// mma root dimension from the mma accumulator tv.
for (auto tv_id : tv->getMaybeRFactorDomain()) {
for (auto tv_id : tv_root_domain) {
if (std::any_of(
mma_root_dimensions.begin(),
mma_root_dimensions.end(),
Expand Down Expand Up @@ -483,7 +508,8 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
TORCH_INTERNAL_ASSERT(
canValidateIsInnerDim(k_dims.back(), tv->axis(-1), 16),
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain",
tv->toString());

//[16m, 16k]
tv->split(-2, 8);
Expand Down
98 changes: 78 additions & 20 deletions torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1499,32 +1499,64 @@ void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) {
auto instruction_tile = tile.instruction_tile;

TORCH_CHECK(
warp_tile.k == cta_tile.k,
"schedule warp tile: currently no support for splitting k dimension to different warps");
cta_tile.k % warp_tile.k == 0,
"Number of warp on k dimension need to be integer");

int num_warp_k = cta_tile.k / warp_tile.k;

mma_util::checkDimSize(
tv, {-3, -2, -1}, {cta_tile.m, cta_tile.n, cta_tile.k});

// -3 -2 -1
//[... M, N, K]

// Distribute warp tile:
tv->split(-3, warp_tile.m);
tv->split(-2, warp_tile.n);
if (num_warp_k == 1) {
// Non split K over warp case:

// -5 -4 -3 -2 -1
// [Mwo Mw Nwo Nw K]
tv->split(-4, instruction_tile.m);
tv->split(-2, instruction_tile.n);
tv->split(-1, instruction_tile.k);
// -3 -2 -1
//[... M, N, K]
// Distribute warp tile:
tv->split(-3, warp_tile.m);
tv->split(-2, warp_tile.n);

// -8 -7 -6 -5 -4 -3 -2 -1
// [Mwo Mw Mi Nwo Nw Ni Ko Ki]
// -5 -4 -3 -2 -1
// [Mwo Mw Nwo Nw K]
tv->split(-4, instruction_tile.m);
tv->split(-2, instruction_tile.n);
tv->split(-1, instruction_tile.k);

tv->reorder({{-7, -5}, {-6, -3}, {-5, -7}, {-3, -2}, {-2, -6}});
// -8 -7 -6 -5 -4 -3 -2 -1
// [Mwo Mw Mi Nwo Nw Ni Ko Ki]

// -8 -7 -6 -5 -4 -3 -2 -1
// [Mwo Nwo Ko Mw Nw Mi Ni Ki]
tv->reorder({{-7, -5}, {-6, -3}, {-5, -7}, {-3, -2}, {-2, -6}});
// -8 -7 -6 -5 -4 -3 -2 -1
// [Mwo Nwo Ko Mw Nw Mi Ni Ki]
} else {
// Split K over warp case:
// Main difference is that an additional
// thread dimension needs to be reserved
// for cross warp reduction:
// -3 -2 -1
//[... M, N, K]
// Distribute warp tile:
tv->split(-3, warp_tile.m);
tv->split(-2, warp_tile.n);
tv->split(-1, warp_tile.k);

// -6 -5 -4 -3 -2 -1
// [Mwo Mw Nwo Nw K, Kw]
tv->split(-5, instruction_tile.m);
tv->split(-3, instruction_tile.n);
tv->split(-1, instruction_tile.k);

// -9 -8 -7 -6 -5 -4 -3 -2 -1
// [Mwo Mw Mi Nwo Nw Ni Kwo Kw Ki]

tv->reorder({{-8, -6}, {-7, -3}, {-6, -8}, {-4, -2}, {-3, -7}, {-2, -4}});
// -9 -8 -7 -6 -5 -4 -3 -2 -1
// [Mwo Nwo Ko Mw Nw Kw, Mi Ni Ki]

tv->merge(-9);
// -8 -7 -6 -5 -4 -3 -2 -1
// [MNwo Ko Mw Nw Kw, Mi Ni Ki]
}
}

void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) {
Expand All @@ -1536,6 +1568,12 @@ void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) {

mma_util::checkDimSize(tv, {-2, -1}, {cta_tile.m, cta_tile.n});

TORCH_CHECK(
cta_tile.k % warp_tile.k == 0,
"Number of warp on k dimension need to be integer");

int num_warp_k = cta_tile.k / warp_tile.k;

// -2 -1
//[... M, N]

Expand All @@ -1555,6 +1593,14 @@ void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) {

// -6 -5 -4 -3 -2 -1
// [Mwo Nwo Mw Nw Mi Ni]

if (num_warp_k != 1) {
// The non reduction warps are merged together
// to save one thread dim for cross dim reduce.
tv->merge(-6);
// -5 -4 -3 -2 -1
// [MNo Mw Nw Mi Ni]
}
}

//! Split the innermost dim to a vectorized load
Expand All @@ -1568,9 +1614,21 @@ void scheduleContiguousVectorLoad(
tv->split(-1, num_of_thread * vector_word);
tv->split(-1, vector_word);
// [..., thread, vec]
// distribute to warp:
// distribute to warp: for tidx
tv->split(-2, 32);
tv->split(-3, warp_dims.n * warp_dims.k);

// -3 -2 -1
// [...warp, lane, vec]

if (warp_dims.k == 1) {
// -4 -3 -2 -1
// [...warpM, warpN, lane, vec]
tv->split(-3, warp_dims.n);
} else {
// -4 -3 -2 -1
// [...warpMN, warpR, lane, vec]
tv->split(-3, warp_dims.k);
}

tv->axis(-1)->parallelize(ParallelType::Vectorize);
tv->axis(-2)->parallelize(ParallelType::TIDx);
Expand Down
51 changes: 33 additions & 18 deletions torch/csrc/jit/codegen/cuda/tensor_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,11 +581,11 @@ TensorView* TensorView::rFactor(const std::vector<int>& axes) {
// !hasComputeAt(), "Cannot rfactor tensors after compute at has been
// set.");
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
TORCH_INTERNAL_ASSERT(definition()->isA<ReductionOp>());
FusionGuard fg(fusion());
TORCH_CHECK(
definition() != nullptr &&
definition()->getExprType() == ExprType::ReductionOp,
definition()->getExprType() == ExprType::ReductionOp ||
definition()->getExprType() == ExprType::MmaOp,
"Error rfactoring ",
this,
" its definition is either a nullptr or not a reduction.");
Expand All @@ -596,8 +596,6 @@ TensorView* TensorView::rFactor(const std::vector<int>& axes) {
!definition()->isA<GroupedReductionOp>(),
"For GroupedReducitonOp, use TensorView::rFactor(const std::vector<int>& axes, const std::vector<TensorView*>& tvs)");

ReductionOp* this_definition = definition()->as<ReductionOp>();

// Split tensor view into 2 parts
auto domain_pair = domain()->rFactor(axes);

Expand All @@ -614,21 +612,38 @@ TensorView* TensorView::rFactor(const std::vector<int>& axes) {
setDomain(consumer_domain);
TensorView* consumer = this;

// Setup dependency chain, inserting producer before this op.
// Expr* producer_definition =
IrBuilder::create<ReductionOp>(
this_definition->getReductionOpType(),
this_definition->init(),
producer,
this_definition->in());

// Expr* consumer_definition =
IrBuilder::create<ReductionOp>(
this_definition->getReductionOpType(),
this_definition->init(),
consumer,
producer);
if (auto this_reduction = dynamic_cast<ReductionOp*>(definition())) {
// Setup dependency chain, inserting producer before this op.
// Expr* producer_definition =
IrBuilder::create<ReductionOp>(
this_reduction->getReductionOpType(),
this_reduction->init(),
producer,
this_reduction->in());

// Expr* consumer_definition =
IrBuilder::create<ReductionOp>(
this_reduction->getReductionOpType(),
this_reduction->init(),
consumer,
producer);
} else if (auto this_mma = dynamic_cast<MmaOp*>(definition())) {
// Initial reduction that still uses mma to combine
// the input.
IrBuilder::create<MmaOp>(
producer,
this_mma->inA(),
this_mma->inB(),
this_mma->init(),
this_mma->options());

// Remaining reduction that can be scheduled cross
// warp or cta.
IrBuilder::create<ReductionOp>(
BinaryOpType::Add, this_mma->init(), consumer, producer);
} else {
TORCH_INTERNAL_ASSERT(false, "RFactor: unsupported tensor definition");
}
return producer;
}

Expand Down
Loading

0 comments on commit f008140

Please sign in to comment.