diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index 8c776be086f1d..35190ea6f0908 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -171,8 +171,26 @@ namespace { // to the given mma dimension. See [MMA dimension matching]. std::vector getMmaDomains(MmaOp* mma, MmaDimension dimension) { // This utility is user facing so shouldn't ever see tensor index here. - auto accumulator_domain = - mma->out()->as()->getMaybeRFactorDomain(); + + // Note: [Use Root Domain in Accumulator TV] + // Have to use root domain for accumulator tv since the operands do not have + // root/rfactor domains that map to the rfactor domain of output. + // For example: + // C[I,I,R,R] = mma (A[I,B,I,I], B[B,I,I,I]), + // if we do + // c->split(-1,4); + // c->rfactor(-1); + // on the mma stage we get: + // C[I,I,R,Io,R(4)] = mma (A[I,B,I,I], B[B,I,I,I]), + // and in this case Io and R(4) would not be able to find root mapping + // in A or B. + // + // Essentially in the case of rfactor, this utility does producer side + // matching so looking at root domain would be required. + // This matching pattern should support most common matmul applications, + // but in follow ups we may need to extend RFactor matching if there + // are more complex scheduling patterns that we want to support. + auto accumulator_domain = mma->out()->as()->getRootDomain(); auto a_domain = TensorDomain::noReductions( mma->inA()->as()->getMaybeRFactorDomain()); auto b_domain = TensorDomain::noReductions( @@ -269,10 +287,17 @@ std::vector getMmaRootDimensions( std::vector result; + // Need to use root domain for accumulator tv and maybe rfactor domain + // otherwise. See [Use Root Domain in Accumulator TV]. + auto is_mma_output = + tv->definition() != nullptr && tv->definition()->isA(); + const auto& tv_root_domain = + is_mma_output ? tv->getRootDomain() : tv->getMaybeRFactorDomain(); + // Loop through tensorview's root domains and accumulate all the // root domain IterDomain's that maps to any of the collected // mma root dimension from the mma accumulator tv. - for (auto tv_id : tv->getMaybeRFactorDomain()) { + for (auto tv_id : tv_root_domain) { if (std::any_of( mma_root_dimensions.begin(), mma_root_dimensions.end(), @@ -483,7 +508,8 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) { "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); TORCH_INTERNAL_ASSERT( canValidateIsInnerDim(k_dims.back(), tv->axis(-1), 16), - "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); + "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain", + tv->toString()); //[16m, 16k] tv->split(-2, 8); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index 56310f226fde6..20869353c201f 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -1499,32 +1499,64 @@ void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) { auto instruction_tile = tile.instruction_tile; TORCH_CHECK( - warp_tile.k == cta_tile.k, - "schedule warp tile: currently no support for splitting k dimension to different warps"); + cta_tile.k % warp_tile.k == 0, + "Number of warp on k dimension need to be integer"); + + int num_warp_k = cta_tile.k / warp_tile.k; mma_util::checkDimSize( tv, {-3, -2, -1}, {cta_tile.m, cta_tile.n, cta_tile.k}); - // -3 -2 -1 - //[... M, N, K] - - // Distribute warp tile: - tv->split(-3, warp_tile.m); - tv->split(-2, warp_tile.n); + if (num_warp_k == 1) { + // Non split K over warp case: - // -5 -4 -3 -2 -1 - // [Mwo Mw Nwo Nw K] - tv->split(-4, instruction_tile.m); - tv->split(-2, instruction_tile.n); - tv->split(-1, instruction_tile.k); + // -3 -2 -1 + //[... M, N, K] + // Distribute warp tile: + tv->split(-3, warp_tile.m); + tv->split(-2, warp_tile.n); - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mwo Mw Mi Nwo Nw Ni Ko Ki] + // -5 -4 -3 -2 -1 + // [Mwo Mw Nwo Nw K] + tv->split(-4, instruction_tile.m); + tv->split(-2, instruction_tile.n); + tv->split(-1, instruction_tile.k); - tv->reorder({{-7, -5}, {-6, -3}, {-5, -7}, {-3, -2}, {-2, -6}}); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Mw Mi Nwo Nw Ni Ko Ki] - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mwo Nwo Ko Mw Nw Mi Ni Ki] + tv->reorder({{-7, -5}, {-6, -3}, {-5, -7}, {-3, -2}, {-2, -6}}); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Nwo Ko Mw Nw Mi Ni Ki] + } else { + // Split K over warp case: + // Main difference is that an additional + // thread dimension needs to be reserved + // for cross warp reduction: + // -3 -2 -1 + //[... M, N, K] + // Distribute warp tile: + tv->split(-3, warp_tile.m); + tv->split(-2, warp_tile.n); + tv->split(-1, warp_tile.k); + + // -6 -5 -4 -3 -2 -1 + // [Mwo Mw Nwo Nw K, Kw] + tv->split(-5, instruction_tile.m); + tv->split(-3, instruction_tile.n); + tv->split(-1, instruction_tile.k); + + // -9 -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Mw Mi Nwo Nw Ni Kwo Kw Ki] + + tv->reorder({{-8, -6}, {-7, -3}, {-6, -8}, {-4, -2}, {-3, -7}, {-2, -4}}); + // -9 -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Nwo Ko Mw Nw Kw, Mi Ni Ki] + + tv->merge(-9); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [MNwo Ko Mw Nw Kw, Mi Ni Ki] + } } void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) { @@ -1536,6 +1568,12 @@ void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) { mma_util::checkDimSize(tv, {-2, -1}, {cta_tile.m, cta_tile.n}); + TORCH_CHECK( + cta_tile.k % warp_tile.k == 0, + "Number of warp on k dimension need to be integer"); + + int num_warp_k = cta_tile.k / warp_tile.k; + // -2 -1 //[... M, N] @@ -1555,6 +1593,14 @@ void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) { // -6 -5 -4 -3 -2 -1 // [Mwo Nwo Mw Nw Mi Ni] + + if (num_warp_k != 1) { + // The non reduction warps are merged together + // to save one thread dim for cross dim reduce. + tv->merge(-6); + // -5 -4 -3 -2 -1 + // [MNo Mw Nw Mi Ni] + } } //! Split the innermost dim to a vectorized load @@ -1568,9 +1614,21 @@ void scheduleContiguousVectorLoad( tv->split(-1, num_of_thread * vector_word); tv->split(-1, vector_word); // [..., thread, vec] - // distribute to warp: + // distribute to warp: for tidx tv->split(-2, 32); - tv->split(-3, warp_dims.n * warp_dims.k); + + // -3 -2 -1 + // [...warp, lane, vec] + + if (warp_dims.k == 1) { + // -4 -3 -2 -1 + // [...warpM, warpN, lane, vec] + tv->split(-3, warp_dims.n); + } else { + // -4 -3 -2 -1 + // [...warpMN, warpR, lane, vec] + tv->split(-3, warp_dims.k); + } tv->axis(-1)->parallelize(ParallelType::Vectorize); tv->axis(-2)->parallelize(ParallelType::TIDx); diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 71983b1f162c9..1ee69ea034e56 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -581,11 +581,11 @@ TensorView* TensorView::rFactor(const std::vector& axes) { // !hasComputeAt(), "Cannot rfactor tensors after compute at has been // set."); TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView"); - TORCH_INTERNAL_ASSERT(definition()->isA()); FusionGuard fg(fusion()); TORCH_CHECK( definition() != nullptr && - definition()->getExprType() == ExprType::ReductionOp, + definition()->getExprType() == ExprType::ReductionOp || + definition()->getExprType() == ExprType::MmaOp, "Error rfactoring ", this, " its definition is either a nullptr or not a reduction."); @@ -596,8 +596,6 @@ TensorView* TensorView::rFactor(const std::vector& axes) { !definition()->isA(), "For GroupedReducitonOp, use TensorView::rFactor(const std::vector& axes, const std::vector& tvs)"); - ReductionOp* this_definition = definition()->as(); - // Split tensor view into 2 parts auto domain_pair = domain()->rFactor(axes); @@ -614,21 +612,38 @@ TensorView* TensorView::rFactor(const std::vector& axes) { setDomain(consumer_domain); TensorView* consumer = this; - // Setup dependency chain, inserting producer before this op. - // Expr* producer_definition = - IrBuilder::create( - this_definition->getReductionOpType(), - this_definition->init(), - producer, - this_definition->in()); - - // Expr* consumer_definition = - IrBuilder::create( - this_definition->getReductionOpType(), - this_definition->init(), - consumer, - producer); + if (auto this_reduction = dynamic_cast(definition())) { + // Setup dependency chain, inserting producer before this op. + // Expr* producer_definition = + IrBuilder::create( + this_reduction->getReductionOpType(), + this_reduction->init(), + producer, + this_reduction->in()); + // Expr* consumer_definition = + IrBuilder::create( + this_reduction->getReductionOpType(), + this_reduction->init(), + consumer, + producer); + } else if (auto this_mma = dynamic_cast(definition())) { + // Initial reduction that still uses mma to combine + // the input. + IrBuilder::create( + producer, + this_mma->inA(), + this_mma->inB(), + this_mma->init(), + this_mma->options()); + + // Remaining reduction that can be scheduled cross + // warp or cta. + IrBuilder::create( + BinaryOpType::Add, this_mma->init(), consumer, producer); + } else { + TORCH_INTERNAL_ASSERT(false, "RFactor: unsupported tensor definition"); + } return producer; } diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 892e6943c1e39..fb4591ee5ada1 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -3242,6 +3242,499 @@ TEST_F(NVFuserTest, FusionAmpereStridedBatchedMatmulTN_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(ref, 0.0001, 0.0001)); } +// Matmul test on Ampere +TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 511, N = 257, K = 88; + int Ko = 11, Ki = 8; + + // [M,Ko,Ki] + auto tv0 = makeContigTensor(3, DataType::Half); + // [N,K] + auto tv1 = makeContigTensor(2, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv0_view = view(tv0, {M, Ko, Ki}, {M, K}); + + // [M,N,K] + auto tv0b = broadcast(tv0_view, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. + auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + mma_builder.configureMma(tv2); + + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); + auto tv0cw = tv0_view->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1r->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + tv0_view->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [... Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + + // Inline the view op with the shared mem write minus + // the vectorization axes for now. + tv0_view->computeAt(tv0cw, -2); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, Ko, Ki}, options); + auto t1 = at::randn({N, K}, options); + + FusionExecutor fe; + + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, 0, fe.compileFusion(&fusion, {t0, t1})); + + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = + at::native::view(t0, {M, K}).to(at::kFloat).matmul(t1.t().to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +// Initial test case for in-CTA split K with VoltaMMA +TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossWarp_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 120, N = 264, K = 120; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [N,K] + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. + auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 16); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); + + auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + mma_builder.configureMma(tv2); + + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); + auto tv0cw = tv0b->cacheAfter(); + auto tv0cr = tv0cw->cacheAfter(); + auto tv1cw = tv1b->cacheAfter(); + auto tv1cr = tv1cw->cacheAfter(); + auto tv2c = tv2->cacheBefore(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + auto tv2c_rf = tv2c->rFactor({-9, -4, -1}); + + // tv2c_rf is the actual output of the mma op after + // Rfactoring. + mma_builder.accumulatorTv(tv2c_rf); + + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c_rf, -4); + tv1cr->computeAt(tv2c_rf, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,No,Ko,M,N,K] + tv0cw->reorder({ + {-3, -2}, + {-2, -3}, + }); + // [Mo,No,Ko,N,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [Mo,No,Ko,M,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [Mo,No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c_rf->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + + tv0b->computeAt(tv0cw, -2); + tv1b->computeAt(tv1cw, -2); + + tv0cr->axis(-1)->parallelize(ParallelType::Vectorize); + tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c_rf->axis(0)->parallelize(ParallelType::BIDx); + tv2c_rf->axis(1)->parallelize(ParallelType::BIDy); + tv2c_rf->axis(3)->parallelize(ParallelType::TIDz); + tv2c_rf->axis(4)->parallelize(ParallelType::TIDy); + + tv2c->axis(2)->parallelize(ParallelType::TIDz); + tv2c->axis(3)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({N, K}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +// Initial test case for cross-CTA split K with VoltaMMA +TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossCTA_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 120, N = 264, K = 120; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [N,K] + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. + auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); + + auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + mma_builder.configureMma(tv2); + + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); + auto tv0cw = tv0b->cacheAfter(); + auto tv0cr = tv0cw->cacheAfter(); + auto tv1cw = tv1b->cacheAfter(); + auto tv1cr = tv1cw->cacheAfter(); + auto tv2c = tv2->cacheBefore(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->split(-2, 2, true); + // Order K + // 0 1 2 3 4 5 6 + // [Mo,No, M128, N128, Ko, K2CTA, K32] + tv2c->reorder({{2, 4}, {3, 5}, {4, 3}, {5, 2}}); + // 0 1 2 3 4 5 6 + // [Mo,No, K2CTA, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 4); + tv1r->computeAt(tv2c, 4); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + auto tv2c_rf = tv2c->rFactor({-9, -6, -1}); + + // tv2c_rf is the actual output of the mma op after + // Rfactoring. + mma_builder.accumulatorTv(tv2c_rf); + + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c_rf, -4); + tv1cr->computeAt(tv2c_rf, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,No,Ko,M,N,K] + tv0cw->reorder({ + {-3, -2}, + {-2, -3}, + }); + // [Mo,No,Ko,N,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [Mo,No,Ko,M,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [Mo,No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c_rf->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + + tv0b->computeAt(tv0cw, -2); + tv1b->computeAt(tv1cw, -2); + + tv0cr->axis(-1)->parallelize(ParallelType::Vectorize); + tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c_rf->axis(0)->parallelize(ParallelType::BIDx); + tv2c_rf->axis(1)->parallelize(ParallelType::BIDy); + tv2c_rf->axis(2)->parallelize(ParallelType::BIDz); + tv2c_rf->axis(4)->parallelize(ParallelType::TIDz); + tv2c_rf->axis(5)->parallelize(ParallelType::TIDy); + + tv2c->axis(0)->parallelize(ParallelType::BIDx); + tv2c->axis(1)->parallelize(ParallelType::BIDy); + tv2c->axis(2)->parallelize(ParallelType::BIDz); + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({N, K}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + #undef NVFUSER_TEST_CUDA_ARCH_GUARD } // namespace jit