diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
index cc0637f00071c9..43524cdfd30588 100644
--- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -1031,6 +1031,7 @@ void validateMma(Fusion* fusion) {
       validateMinimumArch(7, 0);
       break;
     case MmaOptions::MacroType::Turing_16_8_16:
+    case MmaOptions::MacroType::Turing_16_16_16:
       validateMinimumArch(7, 5);

       // Check that operands come from ldmatrix, can be
@@ -1039,6 +1040,7 @@ void validateMma(Fusion* fusion) {
       validateTuringMmaInput(mma->inB()->as<TensorView>());
       break;
     case MmaOptions::MacroType::Ampere_16_8_16:
+    case MmaOptions::MacroType::Ampere_16_16_16:
       validateMinimumArch(8, 0);

       // Check that operands come from ldmatrix, can be
diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp
index 90e9df51fdefc3..8588d6845554b5 100644
--- a/torch/csrc/jit/codegen/cuda/mma_type.cpp
+++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp
@@ -32,6 +32,10 @@ MmaBuilder::MmaBuilder(
     case MmaOptions::MacroType::Ampere_16_8_16:
       option_.accumulator_stride = outer_stride * 2;
       break;
+    case MmaOptions::MacroType::Ampere_16_16_16:
+    case MmaOptions::MacroType::Turing_16_16_16:
+      option_.accumulator_stride = outer_stride * 4;
+      break;
     default:
       TORCH_CHECK(false, "unsupported macro");
       break;
@@ -84,6 +88,8 @@ LoadStoreOpType getLdMatrixType(MmaOptions options) {
   switch (options.macro) {
     case MmaOptions::MacroType::Turing_16_8_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
+    case MmaOptions::MacroType::Ampere_16_16_16:
+    case MmaOptions::MacroType::Turing_16_16_16:
       // Turing mma assumes TN as default
       transpose = (options.operand == MmaOptions::Operand::A &&
                    !isOperandTransposed(options)) ||
@@ -109,16 +115,20 @@ bool isVolta(MmaOptions::MacroType macro) {
 }

 bool isTuring(MmaOptions::MacroType macro) {
-  return macro == MmaOptions::MacroType::Turing_16_8_16;
+  return macro == MmaOptions::MacroType::Turing_16_8_16 ||
+      macro == MmaOptions::MacroType::Turing_16_16_16;
 }

 bool isAmpere(MmaOptions::MacroType macro) {
-  return macro == MmaOptions::MacroType::Ampere_16_8_16;
+  return macro == MmaOptions::MacroType::Ampere_16_8_16 ||
+      macro == MmaOptions::MacroType::Ampere_16_16_16;
 }

 int getOutputRegisterSize(MmaOptions::MacroType macro) {
   switch (macro) {
     case MmaOptions::MacroType::Volta_16_16_4:
+    case MmaOptions::MacroType::Ampere_16_16_16:
+    case MmaOptions::MacroType::Turing_16_16_16:
       return 8;
       break;
     case MmaOptions::MacroType::Turing_16_8_16:
@@ -138,7 +148,9 @@ int getInputARegisterSize(MmaOptions::MacroType macro) {
       return 4;
       break;
     case MmaOptions::MacroType::Turing_16_8_16:
+    case MmaOptions::MacroType::Turing_16_16_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
+    case MmaOptions::MacroType::Ampere_16_16_16:
       return 8;
       break;
     default:
@@ -156,6 +168,9 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) {
     case MmaOptions::MacroType::Turing_16_8_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
       return 4;
+    case MmaOptions::MacroType::Turing_16_16_16:
+    case MmaOptions::MacroType::Ampere_16_16_16:
+      return 8;
     default:
       TORCH_INTERNAL_ASSERT(false, "unknown macro");
       break;
@@ -208,6 +223,10 @@ std::string toString(MmaOptions::MacroType mt) {
     case MmaOptions::MacroType::Ampere_16_8_16:
       ss << "M16N8K16";
       break;
+    case MmaOptions::MacroType::Turing_16_16_16:
+    case MmaOptions::MacroType::Ampere_16_16_16:
+      ss << "M16N16K16";
+      break;
     default:
       TORCH_INTERNAL_ASSERT(false, "undefined mma type");
       break;
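For reference, the register sizes hard-coded above for the new macros follow from simple warp arithmetic: a 16x16 fragment holds 256 elements, i.e. 8 per thread of a 32-thread warp, versus 4 for the 16x8 fragments of the native macro. A standalone sanity check of that accounting (illustration only, not part of the patch; `regsPerThread` is a hypothetical helper assuming one fragment element per register):

```cpp
// Illustration only: per-thread register counts implied by the emulated
// 16x16x16 macro, matching getOutputRegisterSize / getInputARegisterSize /
// getInputBRegisterSize in mma_type.cpp above.
constexpr int kWarpSize = 32;

// Elements of a rows x cols fragment owned by each thread of a warp,
// assuming one element per register.
constexpr int regsPerThread(int rows, int cols) {
  return rows * cols / kWarpSize;
}

static_assert(regsPerThread(16, 16) == 8, "C, A and B of m16n16k16: 8 regs each");
static_assert(regsPerThread(16, 8) == 4, "C and B of the native m16n8k16: 4 regs");
static_assert(
    regsPerThread(16, 16) == 2 * regsPerThread(16, 8),
    "two native 16x8 tiles per emulated 16x16 tile");
```

The change from `outer_stride * 2` to `outer_stride * 4` for `accumulator_stride` is consistent with the same 2x factor: each thread now carries two m16n8k16 accumulator groups instead of one.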
diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h
index 56bd07de475c28..7874573a3d01b6 100644
--- a/torch/csrc/jit/codegen/cuda/mma_type.h
+++ b/torch/csrc/jit/codegen/cuda/mma_type.h
@@ -62,7 +62,9 @@ struct MmaOptions {
     NoMMA = 0,
     Volta_16_16_4,
     Ampere_16_8_16,
+    Ampere_16_16_16,
     Turing_16_8_16,
+    Turing_16_16_16,
     Ampere_16_8_8 // place holder for tf32
   };

diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu
index c6976c197328d0..4b7f678fa6e889 100644
--- a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu
+++ b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu
@@ -276,6 +276,30 @@ DEVICE_INLINE void M16N8K16TN(
   _C[acc_stride + 1] = C_data[3];
 }

+template <int acc_stride>
+DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
+  float* _C = reinterpret_cast<float*>(accumulator);
+  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
+  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
+}
+
+template <int acc_stride>
+DEVICE_INLINE void M16N16K16TN(
+    Array<float, 8, 8>* C,
+    Array<__half, 8, 8>* A,
+    Array<__half, 8, 8>* B) {
+  float* _C = reinterpret_cast<float*>(C);
+  __half* _B = reinterpret_cast<__half*>(B);
+  M16N8K16TN<acc_stride>(
+      reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
+      A,
+      reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
+  M16N8K16TN<acc_stride>(
+      reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
+      A,
+      reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
+}
+
 } // namespace Turing

 #endif // Arch 75
@@ -338,6 +362,30 @@ DEVICE_INLINE void M16N8K16TN(
   _C[acc_stride + 1] = C_data[3];
 }

+template <int acc_stride>
+DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
+  float* _C = reinterpret_cast<float*>(accumulator);
+  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
+  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
+}
+
+template <int acc_stride>
+DEVICE_INLINE void M16N16K16TN(
+    Array<float, 8, 8>* C,
+    Array<__half, 8, 8>* A,
+    Array<__half, 8, 8>* B) {
+  float* _C = reinterpret_cast<float*>(C);
+  __half* _B = reinterpret_cast<__half*>(B);
+  M16N8K16TN<acc_stride>(
+      reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
+      A,
+      reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
+  M16N8K16TN<acc_stride>(
+      reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
+      A,
+      reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
+}
+
 } // namespace Ampere

 #endif // Arch 80
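The `M16N16K16TN` helpers above emulate the 16x16x16 tile by issuing the existing `M16N8K16TN` wrapper twice with a shared A fragment and the two halves of B and C. For context, each `M16N8K16TN` call covers a 16x8x16 HMMA shape; a minimal standalone sketch of that shape as a single sm_80 `mma.sync` instruction (illustration only, not copied from tensorcore.cu; the function name and register-array parameters are assumptions, and the Turing path may lower the same shape through narrower issues):

```cpp
// Sketch of one mma.sync.m16n8k16 issue (PTX ISA operand shapes):
//   D/C: m16 x n8 f32 accumulator -> 4 registers per thread
//   A:   m16 x k16 f16 fragment   -> 8 halves = 4 .b32 registers per thread
//   B:   k16 x n8  f16 fragment   -> 4 halves = 2 .b32 registers per thread
__device__ void m16n8k16_tn_sketch(
    float (&c)[4],
    const unsigned (&a)[4],
    const unsigned (&b)[2]) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n"
      : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
}
```

The point of the emulated macro is that a single ldmatrix.x4 load of either operand feeds both HMMA issues, since the A registers are reused and B/C are simply split in half along N.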
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp
index bf41cdfff2606d..6abd4dd56c4731 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp
@@ -134,6 +134,13 @@ void WarpMmaSwizzler::scheduleMmaWarpOutput(
         setWarpMapped(tv, 4);
       }
       break;
+    case MmaOptions::MacroType::Turing_16_16_16:
+    case MmaOptions::MacroType::Ampere_16_16_16:
+      scheduleTuringM16N16K16MmaWarpOutput(tv, options);
+      if (tv->definition()->isA<MmaOp>()) {
+        setWarpMapped(tv, 4);
+      }
+      break;
     default:
       TORCH_CHECK(
           false, "scheduleMmaWarp: unsupported mma option ", toString(macro));
@@ -151,6 +158,8 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) {
       break;
     case MmaOptions::MacroType::Turing_16_8_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
+    case MmaOptions::MacroType::Turing_16_16_16:
+    case MmaOptions::MacroType::Ampere_16_16_16:
       scheduleTuringOperandRead(tv, options);
       break;
     default:
@@ -505,7 +514,9 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
   // Check mma option is supported
   TORCH_CHECK(
       options.macro == MmaOptions::MacroType::Ampere_16_8_16 ||
-          options.macro == MmaOptions::MacroType::Turing_16_8_16,
+          options.macro == MmaOptions::MacroType::Ampere_16_16_16 ||
+          options.macro == MmaOptions::MacroType::Turing_16_8_16 ||
+          options.macro == MmaOptions::MacroType::Turing_16_16_16,
       "scheduleLdMatrix: unknown macro for ldmatrix");

   if (options.operand == MmaOptions::Operand::A) {
@@ -548,43 +559,80 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
     auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N);
     auto k_dims = getMmaRootDimensions(tv, mma, MmaDimension::K);

-    // validation:
-    TORCH_INTERNAL_ASSERT(
-        canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 8),
-        "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
     TORCH_INTERNAL_ASSERT(
         canValidateIsInnerDim(k_dims.back(), tv->axis(-1), 16),
         "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");

-    if (transposed) {
-      // [8, 16]
-      tv->split(-2, 4);
-
-      // [2i, 4i, 16]
-      tv->reorder({{-1, -2}, {-2, -1}});
-      // [2i, 16, 4i]
+    // Each ldmatrix 4 would be loading an effective 16x16x16 tile, which is 2x
+    // the
+    // size of regular 16x8x16 tile supported by largest mma operation. The
+    // swizzle also needs to be different to take this into account.
+    // TODO:
+    //  Using an emulated 16x16x16 mma tile is a temporary step to enable the
+    //  widest load possible for scheduler bring up phase.
+    //  A unifying step would be needed in a follow up to support all these
+    //  swizzles
+    //  with a single affine utility.
+    bool use_ldmatrix4 = canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 16);
+
+    if (use_ldmatrix4) {
+      // [... N16, K16]
+      tv->split(-2, 8);
+      tv->split(-1, 8);
+
+      //       -4  -3   -2  -1
+      // [... N2o, N8, K2o, K8]
+      tv->reorder({{-3, -2}, {-2, -3}});
+      // [... N2o, K2o, N8, K8]
+
+      if (transposed) {
+        tv->reorder({{-1, -2}, {-2, -1}});
+      }
+      tv->merge(-4);
       tv->merge(-3);
-      // [warp, 4i]
-    } else {
-      //[8, 16]
-      tv->split(-1, 4);
-      tv->split(-2, 2);
-
-      //        0   1   2   3
-      //      [8, oo2,oi2,i4]
-      tv->reorder({{-4, -2}, {-2, -4}});
-      //        0   1   2   3
-      //      [oi2, oo2, 8,i4]
+      // [Warp, K8]
+      tv->axis(-2)->parallelize(ParallelType::TIDx);
+      if (is_immediate_output) {
+        tv->axis(-1)->parallelize(ParallelType::Vectorize);
+      }
+    } else {
+      // validation:
+      TORCH_INTERNAL_ASSERT(
+          canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 8),
+          "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
+
+      if (transposed) {
+        // [8, 16]
+        tv->split(-2, 4);
+
+        // [2i, 4i, 16]
+        tv->reorder({{-1, -2}, {-2, -1}});
+        // [2i, 16, 4i]
+
+        tv->merge(-3);
+        // [warp, 4i]
+      } else {
+        //[8, 16]
+        tv->split(-1, 4);
+        tv->split(-2, 2);
+
+        //        0   1   2   3
+        //      [8, oo2,oi2,i4]
+        tv->reorder({{-4, -2}, {-2, -4}});
+
+        //        0   1   2   3
+        //      [oi2, oo2, 8,i4]
+
+        tv->merge(-4);
+        tv->merge(-3);
+        //  0  1
+        // [warp, i4]
+      }

-      tv->merge(-4);
-      tv->merge(-3);
-      //  0  1
-      // [warp, i4]
+      tv->axis(-2)->parallelize(ParallelType::TIDx);
     }
-
-    tv->axis(-2)->parallelize(ParallelType::TIDx);
   } else {
     TORCH_INTERNAL_ASSERT(false, "unreachable");
   }
@@ -717,6 +765,52 @@ void WarpMmaSwizzler::scheduleTuringM16N8K16MmaWarpOutput(
   tv->axis(m_pos)->parallelize(ParallelType::TIDx);
 }

+void WarpMmaSwizzler::scheduleTuringM16N16K16MmaWarpOutput(
+    TensorView* tv,
+    const MmaOptions& options) {
+  // Assume last 2 dims [M16, N8] or [M16, N8, R]
+  // Locate instruction m
+  bool is_reduction = tv->axis(-1)->isReduction();
+
+  // Make sure instruction tile size is correct.
+  if (is_reduction) {
+    validateMmaRootInnerMNK(tv, options, 16, 16, 16);
+  } else {
+    validateMmaRootInnerMN(tv, options, 16, 16);
+  }
+
+  int m_pos = is_reduction ? -3 : -2;
+  //   m
+  // [16, 16 (,R)]
+
+  tv->split(m_pos + 1, 8);
+  //   m
+  // [16, n2, 8 (,R)]
+  tv->reorder({{m_pos, m_pos - 1}, {m_pos - 1, m_pos}});
+
+  //       m
+  // [n2, 16, 8 (,R)]
+  tv->split(m_pos, 8);
+  tv->split(m_pos + 1, 2);
+
+  //        m
+  // [2o, 8o, 4i, 2i (,R)]
+  tv->merge(m_pos - 1);
+
+  //       m
+  // [2o, Warp, 2i (,R)]
+  TORCH_CHECK(tv->definition() != nullptr);
+
+  if (is_reduction && tv->definition()->isA<MmaOp>()) {
+    // Set instruction loops for mma reduce
+    for (int pos : c10::irange(5)) {
+      tv->axis(-pos - 1)->parallelize(ParallelType::Mma);
+    }
+  }
+
+  tv->axis(m_pos)->parallelize(ParallelType::TIDx);
+}
+
 namespace {

 bool isMmaInitLoop(const kir::Scope& loop_body) {
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h
index ef51b64d7095bf..03cbea6d3cffc6 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h
+++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h
@@ -115,18 +115,32 @@ class TORCH_CUDA_CU_API WarpMmaSwizzler {
       MmaOptions options = MmaOptions());

  private:
-  //! Swizzle implementations for Volta mma.
+  //! Operand swizzle implementations for Volta mma.
   static void scheduleVoltaOperandRead(TensorView* tv, MmaOptions options);
+
+  //! Accumulator swizzle implementations for Volta mma.
   static void scheduleVoltaM16N16K4Fp32Output(
       TensorView* tv,
       const MmaOptions& options);

-  //! Swizzle implementations for Turing mma.
+  //! Operand swizzle implementations for Turing and Ampere mma.
   static void scheduleTuringOperandRead(TensorView* tv, MmaOptions options);
+
+  //! Accumulator swizzle implementation for Turing and Ampere mma.
   static void scheduleTuringM16N8K16MmaWarpOutput(
       TensorView* tv,
       const MmaOptions& options);

+  //! Accumulator swizzle implementation for emulated 16x16x16 mma tile
+  //!  that enables using ldmatrix.x4.
+  //! Note:
+  //!  Keeping both this option and the ldmatrix.x2 variant above for
+  //!  now for wider scheduler exploration space. Eventually both of
+  //!  these can be unified with a single affine utility.
+  static void scheduleTuringM16N16K16MmaWarpOutput(
+      TensorView* tv,
+      const MmaOptions& options);
+
   //! Utility to lock the transformed dimensions from further transforms.
   static void setWarpMapped(TensorView* tv, int number_of_dims);
 };
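The new `scheduleTuringM16N16K16MmaWarpOutput` swizzle declared above lays the 16x16 accumulator tile out so that the 32 lanes of the warp cover an (8 row-group x 4 column-pair) grid, while rows r and r+8 and the two N halves of the emulated tile stay with the same lane as serial register dimensions. A small illustration of that ownership, derived from the PTX ISA m16n8k16 accumulator layout rather than from nvFuser's internal register packing (the struct and function names are assumptions):

```cpp
// Which lane owns element (row, col) of the emulated 16x16 f32 accumulator,
// viewed as two m16n8k16 sub-tiles placed side by side along N.
struct AccOwner {
  int lane;      // 0..31
  int sub_tile;  // 0 -> columns 0..7, 1 -> columns 8..15 (second HMMA issue)
};

inline AccOwner ownerOf16x16Acc(int row, int col) {
  AccOwner o;
  o.sub_tile = col / 8;                   // which of the two m16n8 halves
  int sub_col = col % 8;                  // column inside that half
  o.lane = (row % 8) * 4 + sub_col / 2;   // rows r and r+8 share a lane
  return o;
}
```

This mirrors the [2o, Warp, 2i] shape the swizzle produces: 32 lanes in the middle dimension, each holding pairs of adjacent columns, with the r/r+8 row pairs and the N halves kept outside the thread dimension.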
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
index 57ee1ae890ddcf..860ec7af046918 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
@@ -3344,6 +3344,322 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) {
   TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
 }

+// Matmul test on Ampere using ldmatrix.x4 to load operands
+TEST_F(NVFuserTest, FusionAmpereMatmulTNLarge_CUDA) {
+  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
+
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+  int M = 511, N = 257, K = 88;
+
+  // [M,K]
+  auto tv0 = makeContigTensor(2, DataType::Half);
+  // [N,K]
+  auto tv1 = makeContigTensor(2, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [M,N,K]
+  auto tv0b = broadcast(tv0, {false, true, false});
+  auto tv1b = broadcast(tv1, {true, false, false});
+
+  // Leaving both sets of mma inputs for volta outside
+  //  currently since they need to be swizzled.
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 128, 32);
+  gemm_tile.warp_tile = GemmTile(64, 64, 32);
+  gemm_tile.instruction_tile = GemmTile(16, 16, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::TN);
+
+  // Call scheduler
+  MatmulParam params(mma_builder);
+  params.tile_sizes = gemm_tile;
+  scheduleMatmul(tv2, tv0, tv1, params);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({M, K}, options);
+  auto t1 = at::randn({N, K}, options);
+
+  FusionExecutor fe;
+
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      8, 0, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));
+
+  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+}
+
+// Matmul test on Ampere using ldmatrix.x4 to load operands
+TEST_F(NVFuserTest, FusionAmpereMatmulNTLarge_CUDA) {
+  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
+
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+  int M = 512, N = 256, K = 128;
+
+  // [K,M]
+  auto tv0 = makeContigTensor(2, DataType::Half);
+  // [K,N]
+  auto tv1 = makeContigTensor(2, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [K,M,N]
+  auto tv0b = broadcast(tv0, {false, false, true});
+  auto tv1b = broadcast(tv1, {false, true, false});
+
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 128, 32);
+  gemm_tile.warp_tile = GemmTile(64, 64, 32);
+  gemm_tile.instruction_tile = GemmTile(16, 16, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::NT);
+
+  MatmulParam params(mma_builder);
+  params.tile_sizes = gemm_tile;
+  scheduleMatmul(tv2, tv0, tv1, params);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({K, M}, options);
+  auto t1 = at::randn({K, N}, options);
+
+  FusionExecutor fe;
+
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      8, 0, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat));
+
+  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+}
+
+// Matmul test on Ampere using ldmatrix.x4 to load operands
+TEST_F(NVFuserTest, FusionAmpereMatmulTTLarge_CUDA) {
+  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
+
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+  int M = 512, N = 256, K = 128;
+
+  // [M,K]
+  auto tv0 = makeContigTensor(2, DataType::Half);
+  // [K,N]
+  auto tv1 = makeContigTensor(2, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [M,K,N]
+  auto tv0b = broadcast(tv0, {false, false, true});
+  auto tv1b = broadcast(tv1, {true, false, false});
+
+  // Leaving both sets of mma inputs for volta outside
+  //  currently since they need to be swizzled.
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 128, 32);
+  gemm_tile.warp_tile = GemmTile(64, 64, 32);
+  gemm_tile.instruction_tile = GemmTile(16, 16, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::TT);
+
+  MatmulParam params(mma_builder);
+  params.tile_sizes = gemm_tile;
+  scheduleMatmul(tv2, tv0, tv1, params);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({M, K}, options);
+  auto t1 = at::randn({K, N}, options);
+
+  FusionExecutor fe;
+
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      8, 0, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat));
+
+  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+}
+
+// Matmul test on Turing using ldmatrix.x4 to load operands
+TEST_F(NVFuserTest, FusionTuringMatmulTNLarge_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+  int M = 511, N = 257, K = 88;
+
+  // [M,K]
+  auto tv0 = makeContigTensor(2, DataType::Half);
+  // [N,K]
+  auto tv1 = makeContigTensor(2, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [M,N,K]
+  auto tv0b = broadcast(tv0, {false, true, false});
+  auto tv1b = broadcast(tv1, {true, false, false});
+
+  // Leaving both sets of mma inputs for volta outside
+  //  currently since they need to be swizzled.
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 128, 32);
+  gemm_tile.warp_tile = GemmTile(64, 64, 32);
+  gemm_tile.instruction_tile = GemmTile(16, 16, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Turing_16_16_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::TN);
+
+  mma_builder.configureMma(tv2);
+
+  MatmulParam params(mma_builder);
+  params.tile_sizes = gemm_tile;
+  scheduleMatmul(tv2, tv0, tv1, params);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({M, K}, options);
+  auto t1 = at::randn({N, K}, options);
+
+  FusionExecutor fe;
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      7, 5, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));
+
+  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+}
+
+// Matmul test on Turing using ldmatrix.x4 to load operands
+TEST_F(NVFuserTest, FusionTuringMatmulTTLarge_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+  int M = 512, N = 256, K = 128;
+
+  // [M,K]
+  auto tv0 = makeContigTensor(2, DataType::Half);
+  // [K,N]
+  auto tv1 = makeContigTensor(2, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [M,K,N]
+  auto tv0b = broadcast(tv0, {false, false, true});
+  auto tv1b = broadcast(tv1, {true, false, false});
+
+  // Leaving both sets of mma inputs for volta outside
+  //  currently since they need to be swizzled.
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 128, 32);
+  gemm_tile.warp_tile = GemmTile(64, 64, 32);
+  gemm_tile.instruction_tile = GemmTile(16, 16, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Turing_16_16_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::TT);
+
+  mma_builder.configureMma(tv2);
+
+  MatmulParam params(mma_builder);
+  params.tile_sizes = gemm_tile;
+  scheduleMatmul(tv2, tv0, tv1, params);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({M, K}, options);
+  auto t1 = at::randn({K, N}, options);
+
+  FusionExecutor fe;
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      7, 5, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat));
+
+  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+}
+
+// Matmul test on Turing using ldmatrix.x4 to load operands
+TEST_F(NVFuserTest, FusionTuringMatmulNTLarge_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+  int M = 512, N = 256, K = 128;
+
+  // [K,M]
+  auto tv0 = makeContigTensor(2, DataType::Half);
+  // [K,N]
+  auto tv1 = makeContigTensor(2, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [K,M,N]
+  auto tv0b = broadcast(tv0, {false, false, true});
+  auto tv1b = broadcast(tv1, {false, true, false});
+
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 128, 32);
+  gemm_tile.warp_tile = GemmTile(64, 64, 32);
+  gemm_tile.instruction_tile = GemmTile(16, 16, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Turing_16_16_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::NT);
+
+  MatmulParam params(mma_builder);
+  params.tile_sizes = gemm_tile;
+  scheduleMatmul(tv2, tv0, tv1, params);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({K, M}, options);
+  auto t1 = at::randn({K, N}, options);
+
+  FusionExecutor fe;
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      7, 5, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat));
+
+  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+}
+
 #undef NVFUSER_TEST_CUDA_ARCH_GUARD

 } // namespace jit
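All of the tests above set `gemm_tile.instruction_tile` to `GemmTile(16, 16, 16)`, the emulated tile whose operand reads the swizzles map onto ldmatrix.x4. For context, a standalone sketch of that load (illustration only, not the nvFuser runtime code; the function name and parameters are assumptions):

```cpp
// One ldmatrix.x4 issue loads four 8x8 b16 tiles -- a full 16x16 half tile --
// from shared memory, leaving each lane with 8 halves in 4 .b32 registers.
// Each lane passes the shared-memory address of one of the 32 tile rows.
__device__ void ldmatrix_x4_sketch(unsigned (&frag)[4], const void* smem_row) {
  unsigned addr = static_cast<unsigned>(__cvta_generic_to_shared(smem_row));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
      : "r"(addr));
}
```

That is twice the data per issue compared to the ldmatrix.x2 path used by the 16x8x16 macros, which is the motivation stated in mma_utils.h for keeping the emulated Turing_16_16_16 / Ampere_16_16_16 options alongside them during scheduler bring-up.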