diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp
index 96f4a20b88401..7fe057d7383de 100644
--- a/torch/csrc/jit/codegen/cuda/codegen.cpp
+++ b/torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -2414,11 +2414,9 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   void handle(const kir::IntPair* int_pair) {
     const auto def = int_pair->definition();
-    if (print_inline_) {
-      code_ << gen(def);
-    } else {
-      code_ << varName(int_pair);
-    }
+    TORCH_INTERNAL_ASSERT(
+        def != nullptr, "no support for un-inlined int pair yet.");
+    code_ << gen(def);
   }
 
   void handle(const kir::PairSelect* pair_select) {
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp
index 341b8a2a518f5..8989ad06a234b 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -527,13 +527,43 @@ void IndexCompute::handle(Swizzle2D* swizzle_2d) {
   const auto out_x_ind = out_x_it->second;
   const auto out_y_ind = out_y_it->second;
 
-  // Actual swizzle operation is handled via IndexSwizzle pass
-  // all behavior in this pass is directly forward through the
-  // index and extent.
-  index_map_[in_x_id] = out_x_ind;
-  index_map_[in_y_id] = out_y_ind;
-  extent_map_[in_y_id] = getExtent(out_y_id);
-  extent_map_[in_x_id] = getExtent(out_x_id);
+  if (swizzle_mode_ == SwizzleMode::NoSwizzle ||
+      swizzle_mode_ != swizzle_2d->swizzleMode()) {
+    // Handle inactive swizzles by just passing through the index
+    //  and extent information.
+
+    TORCH_INTERNAL_ASSERT(
+        index_map_.count(in_x_id) == index_map_.count(in_y_id),
+        "input indices should be either both defined or both undefined");
+    if (index_map_.count(in_x_id)) {
+      // Only propagate the original index through if the input
+      //  index hasn't already been computed.
+      // TODO:
+      //  This part should be cleaner once we remove the
+      //  second index traversal pass.
+      return;
+    }
+    index_map_[in_x_id] = out_x_ind;
+    index_map_[in_y_id] = out_y_ind;
+    extent_map_[in_y_id] = getExtent(out_y_id);
+    extent_map_[in_x_id] = getExtent(out_x_id);
+  } else {
+    // Generate integer swizzle math if the
+    //  swizzle is activated. See also
+    //  [Note on swizzle mode].
+
+    auto out_pair = IrBuilder::swizzle2DIntExpr(
+        out_x_ind,
+        out_y_ind,
+        getExtent(out_x_id),
+        getExtent(out_y_id),
+        swizzle_2d->swizzleType());
+
+    index_map_[in_x_id] =
+        IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X);
+    index_map_[in_y_id] =
+        IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y);
+  }
 }
 
 void IndexCompute::handle(Expr* e) {
@@ -616,9 +646,31 @@ IndexCompute::IndexCompute(
       reference_halo_extent_map_(std::move(reference_halo_extent_map)) {
   FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
   concrete_id_pass_ = true;
+  swizzle_mode_ = SwizzleMode::Loop;
 }
 
 void IndexCompute::run(const LoopIndexing& loop_indexing) {
+  // Apply any loop swizzles that output to the loop domains.
+  // Currently only loop swizzles that directly output to concrete
+  //  loop domains are supported; these are checked in the
+  //  validateSwizzle lowering pass.
+  // TODO:
+  //  We will gradually enable replaying and mapping of loop
+  //  swizzles in the IR infrastructure, and once that's piped
+  //  through, this part of the logic will be removed.
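+  //
+  // For illustration (assuming ZShape keeps even rows in order and
+  //  reverses odd rows, as in the runtime swizzle math), a loop swizzle
+  //  on a 4x4 tile roughly turns the generated loop nest into:
+  //
+  //   for (int x = 0; x < 4; x++) {
+  //     for (int y = 0; y < 4; y++) {
+  //       int sy = (x % 2 == 0) ? y : (3 - y);
+  //       // Iteration order is permuted; data layout is unchanged.
+  //       ... = T0[x * 4 + sy];
+  //     }
+  //   }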
+  std::unordered_set<Expr*> visited;
+  for (auto loop_id : loop_indexing.loopDomains()) {
+    auto loop_id_def = loop_id->definition();
+    if (loop_id_def != nullptr && loop_id_def->isA<Swizzle2D>()) {
+      if (visited.insert(loop_id_def).second) {
+        handle(loop_id_def);
+      }
+    }
+  }
+
+  // Run through the loop indexing expressions and generate
+  //  the indexing integer math for the concrete ids.
   for (auto expr : loop_indexing.getBackwardExprList()) {
     handle(expr);
   }
@@ -955,6 +1007,7 @@ void IndexSwizzle::run() {
     UpdateLeafIndices update_leaves(td_, indexMap(), extentMap());
     index_map_ = update_leaves.indexMap();
     extent_map_ = update_leaves.extentMap();
+    IndexCompute::swizzle_mode_ = SwizzleMode::Data;
     IndexCompute::run();
   }
 }
@@ -969,7 +1022,8 @@ void IndexSwizzle::handle(Expr* e) {
         return swizzled_ids_.find(id) != swizzled_ids_.end();
       }) ||
       (e->isA<Swizzle2D>() &&
-       e->as<Swizzle2D>()->swizzleType() != Swizzle2DType::NoSwizzle);
+       e->as<Swizzle2D>()->swizzleType() != Swizzle2DType::NoSwizzle &&
+       e->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Data);
   if (!needs_update) {
     return;
   }
@@ -983,8 +1037,6 @@ void IndexSwizzle::handle(Swizzle2D* swizzle_2d) {
   auto out_x_id = swizzle_2d->outX();
   auto out_y_id = swizzle_2d->outY();
-  auto in_x_id = swizzle_2d->inX();
-  auto in_y_id = swizzle_2d->inY();
 
   auto out_x_it = index_map_.find(out_x_id);
   auto out_y_it = index_map_.find(out_y_id);
@@ -998,28 +1050,7 @@ void IndexSwizzle::handle(Swizzle2D* swizzle_2d) {
   TORCH_INTERNAL_ASSERT(
       out_x_it != index_map_.end() && out_y_it != index_map_.end(),
       "Swizzle output indices were not propagated through");
 
-  const auto out_x_ind = out_x_it->second;
-  const auto out_y_ind = out_y_it->second;
-
-  // Can propagate zero only for a few
-  // swizzle types (TODO)
-
-  if (swizzle_2d->swizzleType() != Swizzle2DType::NoSwizzle) {
-    auto out_pair = IrBuilder::swizzle2DIntExpr(
-        out_x_ind,
-        out_y_ind,
-        getExtent(out_x_id),
-        getExtent(out_y_id),
-        swizzle_2d->swizzleType());
-
-    index_map_[in_x_id] =
-        IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X);
-    index_map_[in_y_id] =
-        IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y);
-
-    swizzled_ids_.insert(in_x_id);
-    swizzled_ids_.insert(in_y_id);
-  }
+  IndexCompute::handle(swizzle_2d);
 }
 
 // Used for local and shared index mapping. Returns a map from loops
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h
index 629e8928bac5d..43964e39feb80 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.h
+++ b/torch/csrc/jit/codegen/cuda/index_compute.h
@@ -130,6 +130,13 @@ class IndexCompute : public BackwardVisitor {
   //  map rather than the actual IDs used in the ID expressions.
   bool concrete_id_pass_ = false;
 
+  // Mode of swizzle that is activated in this index compute instance;
+  //  swizzles of any other mode are treated as no-ops.
+  // Currently data mode swizzles are handled as before, in the IndexSwizzle
+  //  pass, while loop mode swizzles are handled early on, in the concrete
+  //  indexing pass. See also [Note on swizzle mode].
+  SwizzleMode swizzle_mode_ = SwizzleMode::NoSwizzle;
+
  public:
   const std::unordered_map<IterDomain*, Val*>& indexMap() const {
     return index_map_;
diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index fe5fd9143c7ee..73804c8d58f58 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -381,7 +381,11 @@ class TORCH_CUDA_CU_API TensorView : public Val {
 
   //! Swizzle the rectangular tile defined by the iterdomains corresponding
   //!  to the 2 given indices.
-  TensorView* swizzle(Swizzle2DType swizzle_type, int x, int y);
+  TensorView* swizzle(
+      Swizzle2DType swizzle_type,
+      int x,
+      int y,
+      SwizzleMode swizzle_mode = SwizzleMode::Data);
 
   // WARNING: rFactor does not return this TensorView; it returns a new
   //  tensorview consumed by this!
diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
index 467c949bc9bbc..4a594728fb5a8 100644
--- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -991,7 +991,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
   static std::pair<IterDomain*, IterDomain*> swizzle(
       Swizzle2DType swizzle_type,
       IterDomain* in_x,
-      IterDomain* in_y);
+      IterDomain* in_y,
+      SwizzleMode swizzle_mode = SwizzleMode::Data);
 
   bool isMmaSwizzled() const {
     return is_mma_swizzled_;
@@ -1198,7 +1199,11 @@ class TORCH_CUDA_CU_API TensorDomain : public Val {
   //! Applies 2D swizzle on a rectangular tile defined by
   //!  a pair of iterdomains contained in this domain.
-  void swizzle(Swizzle2DType swizzle_type, int x, int y);
+  void swizzle(
+      Swizzle2DType swizzle_type,
+      int x,
+      int y,
+      SwizzleMode swizzle_mode = SwizzleMode::Data);
 
   // Transform TensorView according to merge and split transformations
   TensorDomain* view(
@@ -1339,7 +1344,8 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
       IterDomain* out_y,
       IterDomain* in_x,
       IterDomain* in_y,
-      Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle);
+      Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle,
+      SwizzleMode swizzle_mode = SwizzleMode::Data);
 
   Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner);
 
@@ -1359,10 +1365,14 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
     return in_y_;
   }
 
-  const auto& swizzleType() const {
+  auto swizzleType() const {
     return swizzle_type_;
   }
 
+  auto swizzleMode() const {
+    return swizzle_mode_;
+  }
+
   bool sameAs(const Statement* other) const override;
 
  private:
@@ -1377,7 +1387,50 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
 
   // The type of predefined 1-to-1 functions
   //  used for swizzling math.
-  Swizzle2DType swizzle_type_;
+  Swizzle2DType swizzle_type_ = Swizzle2DType::NoSwizzle;
+
+  // Swizzle mode of this swizzle instance.
+  // [Note on swizzle mode]
+  // On the current implementations we support two modes of
+  //  swizzle math, namely, data mode and loop mode.
+  // `Data` mode swizzling is a swizzle that will change the
+  //  data layout in shared memory, likely in global memory buffers
+  //  as well in the future. see also IndexSwizzle in index_compute.cpp.
+  //
+  //  Most important use cases are transpose bank conflict removal, and mma
+  //  swizzled shared memory layout. Example illustrated in 1D case:
+  //
+  //   for (int i = 0; i < I; i++) {
+  //     T0_shared[swizzled(i)] = T0[i];
+  //   }
+  //   for (int i = 0; i < I; i++) {
+  //     T1[i] = T0_shared[swizzled(i)];
+  //   }
+  //
+  // `Loop` mode swizzling does not change the data layout of any buffer;
+  //  it only permutes the iteration order of the loop nest that
+  //  materializes the swizzled domains, e.g.:
+  //
+  //   for (int i = 0; i < I; i++) {
+  //     T1[swizzled(i)] = T0[swizzled(i)];
+  //   }
+  //
+  //  Since the read and write sides are swizzled consistently, no data
+  //   layout changes; only the visiting order of the iterations does,
+  //   which is useful, e.g., for re-ordering block tiles for locality.
+  SwizzleMode swizzle_mode_ = SwizzleMode::Data;
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ ... @@ std::pair<IterDomain*, IterDomain*> IterDomain::stridedSplit(int factor) {
 std::pair<IterDomain*, IterDomain*> IterDomain::swizzle(
     Swizzle2DType swizzle_type,
     IterDomain* in_x,
-    IterDomain* in_y) {
+    IterDomain* in_y,
+    SwizzleMode swizzle_mode) {
   TORCH_CHECK(
       !in_x->extent()->isZeroInt() && !in_y->extent()->isZeroInt(),
       "Invalid swizzling of an empty dimension.");
@@ -1319,7 +1320,7 @@ std::pair<IterDomain*, IterDomain*> IterDomain::swizzle(
   IterDomain* out_y = IterDomainBuilder(in_y).build();
 
   IrBuilder::create<Swizzle2D>(
-      in_x->container(), out_x, out_y, in_x, in_y, swizzle_type);
+      in_x->container(), out_x, out_y, in_x, in_y, swizzle_type, swizzle_mode);
 
   return std::make_pair(out_x, out_y);
 }
@@ -1790,7 +1791,11 @@ std::vector<IterDomain*> TensorDomain::orderedAs(
   return reordered_domain;
 }
 
-void TensorDomain::swizzle(Swizzle2DType swizzle_type, int x, int y) {
+void TensorDomain::swizzle(
+    Swizzle2DType swizzle_type,
+    int x,
+    int y,
+    SwizzleMode swizzle_mode) {
   TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do swizzle on a 0-dim domain");
 
   TORCH_CHECK(
@@ -1808,7 +1813,7 @@ void TensorDomain::swizzle(Swizzle2DType swizzle_type, int x, int y) {
   IterDomain* axis_out_y = nullptr;
 
   std::tie(axis_out_x, axis_out_y) =
-      IterDomain::swizzle(swizzle_type, axis_x, axis_y);
+      IterDomain::swizzle(swizzle_type, axis_x, axis_y, swizzle_mode);
 
   domain_.erase(domain_.begin() + x);
   domain_.insert(domain_.begin() + x, axis_out_x);
@@ -2039,13 +2044,15 @@ Swizzle2D::Swizzle2D(
     IterDomain* out_y,
     IterDomain* in_x,
     IterDomain* in_y,
-    Swizzle2DType swizzle_type)
+    Swizzle2DType swizzle_type,
+    SwizzleMode swizzle_mode)
     : Expr(passkey, ExprType::Swizzle2D),
       out_x_{out_x},
       out_y_{out_y},
       in_x_{in_x},
       in_y_{in_y},
-      swizzle_type_(swizzle_type) {
+      swizzle_type_(swizzle_type),
+      swizzle_mode_(swizzle_mode) {
   addOutput(out_x);
   addOutput(out_y);
   addInput(in_x);
   addInput(in_y);
@@ -2071,7 +2078,8 @@ Swizzle2D::Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner)
       out_y_(ir_cloner->clone(src->out_y_)),
       in_x_(ir_cloner->clone(src->in_x_)),
       in_y_(ir_cloner->clone(src->in_y_)),
-      swizzle_type_(src->swizzle_type_) {}
+      swizzle_type_(src->swizzle_type_),
+      swizzle_mode_(src->swizzle_mode_) {}
 
 NamedScalar::NamedScalar(
     IrBuilderPasskey passkey,
diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
index 01dae72916ee7..faa296732ca0d 100644
--- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
@@ -848,7 +848,7 @@ struct ReplaceValInIndexVal : public OptInDispatch {
   void handle(Val* val) override {
     TORCH_INTERNAL_ASSERT(
-        val->isA<Int>() || val->isA<NamedScalar>(),
+        val->isA<Int>() || val->isA<NamedScalar>() || val->isA<kir::IntPair>(),
         "Invalid Val type: ",
         val->toString());
 
@@ -864,11 +864,17 @@ struct ReplaceValInIndexVal : public OptInDispatch {
     // Recursively traverse its defining expr
     auto def = val->definition();
     if (def != nullptr) {
-      TORCH_INTERNAL_ASSERT(
-          def->isA<UnaryOp>() || def->isA<BinaryOp>(),
-          "Unexpected definition: ",
-          def->toString());
-      handle(val->definition());
+      switch (def->etype()) {
+        case ExprType::UnaryOp:
+        case ExprType::BinaryOp:
+        case ExprType::Swizzle2DInt:
+        case ExprType::PairSelect:
+          handle(val->definition());
+          break;
+        default:
+          TORCH_INTERNAL_ASSERT(
+              false, "Unexpected definition: ", def->toString());
+      }
       // last_visited_val_ is set in the expr handlers
     } else {
       last_visited_val_ = val;
@@ -897,6 +903,36 @@ struct ReplaceValInIndexVal : public OptInDispatch {
     last_visited_val_ = out;
   }
 
+  // Clone expression after recursively replacing inputs
+  void handle(kir::Swizzle2DInt* swizzle_2d) override {
+    handle(swizzle_2d->inX());
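+    // The handle() call above recursively clones the x input's defining
+    //  expression; the handlers leave the (possibly replaced) result in
+    //  last_visited_val_, which is picked up below before visiting y.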
+    auto in_x = last_visited_val_;
+    handle(swizzle_2d->inY());
+    auto in_y = last_visited_val_;
+    auto out = IrBuilder::create<kir::IntPair>();
+
+    // Extents are assumed constant in swizzle so no need to
+    //  duplicate their graphs.
+    IrBuilder::create<kir::Swizzle2DInt>(
+        out,
+        in_x,
+        in_y,
+        swizzle_2d->extentX(),
+        swizzle_2d->extentY(),
+        swizzle_2d->swizzleType());
+    last_visited_val_ = out;
+  }
+
+  void handle(kir::PairSelect* pair_select) override {
+    handle(pair_select->in()->asVal());
+    auto in = last_visited_val_;
+    TORCH_INTERNAL_ASSERT(pair_select->out()->isA<Int>());
+    auto out = IrBuilder::create<Int>(c10::nullopt);
+    IrBuilder::create<kir::PairSelect>(
+        out, in->as<kir::IntPair>(), pair_select->selection());
+    last_visited_val_ = out;
+  }
+
 private:
   const std::unordered_map<Val*, Val*>& replacement_map_;
   Val* last_visited_val_ = nullptr;
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp
index b8258da0b2342..b6e39fa588e8a 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -257,6 +257,9 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
   // Validate mma data format and compatibility if any on the fusion.
   validateMma(fusion_);
 
+  // Validate swizzle usage on the fusion schedule.
+  validateSwizzle(fusion_);
+
   // Compute thread predicates. Depends on parallel_dimension_map_
   thread_pred_map_.build(fusion_);
 
diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
index 43524cdfd3058..f4a31cba13498 100644
--- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -1059,17 +1059,76 @@ void validateMma(Fusion* fusion) {
   }
 }
 
+namespace {
+
+// Utility function to validate a loop swizzle:
+//  1. Throws an error if any output of the swizzle is not in the
+//     leaf_domains set.
+//  2. Warns if any output of the swizzle is not the concrete id of the
+//     loop map.
+// The second case would make the codegen ignore this swizzle, as if it
+//  was not there at all.
+void validateLoopSwizzle(
+    Expr* swizzle_expr,
+    std::unordered_set<IterDomain*>& leaf_domains) {
+  for (auto out_id :
+       ir_utils::filterByType<IterDomain>(swizzle_expr->outputs())) {
+    TORCH_INTERNAL_ASSERT(
+        leaf_domains.count(out_id),
+        "Loop swizzle can only be a direct producer of leaf domains.");
+    if (GpuLower::current()->caMap()->getConcreteMappedID(
+            out_id, IdMappingMode::LOOP) != out_id) {
+      TORCH_WARN_ONCE("Ignored loop swizzle: ", swizzle_expr->toString());
+    }
+  }
+}
+
+} // namespace
+
 void validateSwizzle(Fusion* fusion) {
   auto used_vals = fusion->usedMathVals();
   for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
     if (tv->hasSwizzleOp()) {
+      std::unordered_set<IterDomain*> tv_leaf_domain_set(
+          tv->domain()->domain().begin(), tv->domain()->domain().end());
+
       // Make sure no swizzle op is inlined:
       auto inlined_swizzles = ir_utils::getAllSwizzlesBetween(
           tv->getMaybeRFactorDomain(),
           {tv->domain()->domain().begin(),
            tv->domain()->domain().begin() + tv->getComputeAtPosition()});
-      TORCH_INTERNAL_ASSERT(
-          inlined_swizzles.empty(), "No support for inlined swizzles");
+
+      auto not_inlined_swizzles = ir_utils::getAllSwizzlesBetween(
+          tv->getMaybeRFactorDomain(),
+          {tv->domain()->domain().begin() + tv->getComputeAtPosition(),
+           tv->domain()->domain().end()});
+
+      // Check inlined swizzles: only loop swizzles can be inlined currently,
+      //  as inlining data swizzles would require additional support for an
+      //  unswizzle operator, which currently doesn't have important use
+      //  cases.
+      for (auto swizzle_expr : inlined_swizzles) {
+        TORCH_INTERNAL_ASSERT(
+            swizzle_expr->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Loop,
+            "Only support inlining loop swizzles");
+        validateLoopSwizzle(swizzle_expr, tv_leaf_domain_set);
+      }
+
+      std::unordered_set<Expr*> inlined_swizzle_set(
+          inlined_swizzles.begin(), inlined_swizzles.end());
+
+      // Check not-inlined swizzles:
+      //  Apply the loop swizzle check where it applies, and also make sure
+      //  that no not-inlined swizzle appears in the inlined_swizzle set,
+      //  which would mean that one output of the swizzle is inlined while
+      //  the other is not. Such a case is not supported.
+      for (auto swizzle_expr : not_inlined_swizzles) {
+        TORCH_INTERNAL_ASSERT(
+            !inlined_swizzle_set.count(swizzle_expr),
+            "Cannot partially inline across swizzle domains.",
+            swizzle_expr->toString());
+        if (swizzle_expr->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Loop) {
+          validateLoopSwizzle(swizzle_expr, tv_leaf_domain_set);
+        }
+      }
     }
   }
 }
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
index 8a0d61c878fbd..ca3abc75aabdd 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
@@ -224,6 +224,10 @@ void scheduleMatmul(
   //   and needs more configurability.
   // ------------------------------------------------------------------
   // CTA tile:
+
+  // Swizzle block tiles:
+  c->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop);
+
   a->computeAt(c, 2);
   b->computeAt(c, 2);
 
diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
index 976bc54f0d444..737c2fb610136 100644
--- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp
+++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
@@ -573,7 +573,11 @@ TensorView* TensorView::swizzle(
   return this;
 }
 
-TensorView* TensorView::swizzle(Swizzle2DType swizzle_type, int x, int y) {
+TensorView* TensorView::swizzle(
+    Swizzle2DType swizzle_type,
+    int x,
+    int y,
+    SwizzleMode swizzle_mode) {
   has_swizzle_op_ = true;
   if (x < 0) {
     x += domain()->nDims();
@@ -647,7 +651,7 @@ TensorView* TensorView::swizzle(Swizzle2DType swizzle_type, int x, int y) {
     }
   }
 
-  domain()->swizzle(swizzle_type, x, y);
+  domain()->swizzle(swizzle_type, x, y, swizzle_mode);
 
   return this;
 }
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index a4f48c7855107..dfafa12495084 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -23871,6 +23871,136 @@ TEST_F(NVFuserTest, FusionSwizzleMapping_CUDA) {
       tv1->axis(-1), swizzle_op->outY(), IdMappingMode::PERMISSIVE));
 }
 
+// Test a basic loop swizzle pattern
+TEST_F(NVFuserTest, FusionLoopSwizzle0_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeConcreteTensor({2, 32});
+  fusion.addInput(tv0);
+
+  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
+  auto tv2 = add(tv1, IrBuilder::create<Double>(1));
+
+  fusion.addOutput(tv2);
+
+  tv2->split(-1, 16);
+  tv2->split(-1, 4);
+  //[O, 4, 4]
+
+  tv2->swizzle(Swizzle2DType::ZShape, -2, -1, SwizzleMode::Loop);
+
+  tv0->computeAt(tv2, -1);
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({2, 32}, options);
+  auto t2 = t0 + 2.0;
+  auto cg_outputs = fe.runFusion({t0});
+
+  testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__);
+}
+
+// Outer block zshape pattern
+TEST_F(NVFuserTest, FusionLoopSwizzle1_CUDA) {
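+  // Same ZShape loop swizzle as above, but applied to the outer block
+  //  dimensions, which are bound to BIDx/BIDy below.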
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(2);
+  fusion.addInput(tv0);
+
+  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
+  auto tv2 = add(tv1, IrBuilder::create<Double>(1));
+
+  fusion.addOutput(tv2);
+
+  tv2->split(-2, 8);
+  tv2->split(-1, 4);
+  //[I0o, I0i, I1o, I1i]
+  tv2->reorder({{1, 2}, {2, 1}});
+  //[I0o, I1o, I0i, I1i]
+
+  tv2->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop);
+  tv0->computeAt(tv2, -1);
+
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(1)->parallelize(ParallelType::BIDy);
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({45, 77}, options);
+  auto t2 = t0 + 2.0;
+  auto cg_outputs = fe.runFusion({t0});
+
+  testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__);
+}
+
+// Test assertion in unsupported pattern: non-leaf loop swizzle.
+TEST_F(NVFuserTest, FusionLoopSwizzleCheck0_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeConcreteTensor({2, 32});
+  fusion.addInput(tv0);
+
+  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
+  auto tv2 = add(tv1, IrBuilder::create<Double>(1));
+
+  fusion.addOutput(tv2);
+
+  tv2->split(-1, 16);
+  tv2->split(-1, 4);
+  //[O, 4, 4]
+
+  // Swizzle the inner tile.
+  tv2->swizzle(Swizzle2DType::ZShape, -2, -1, SwizzleMode::Loop);
+
+  // Make swizzle output not a leaf domain.
+  tv2->merge(-2);
+
+  tv0->computeAt(tv2, -1);
+
+  FusionExecutor fe;
+  ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+}
+
+// Test assertion in unsupported pattern: half-inlined loop swizzle.
+TEST_F(NVFuserTest, FusionLoopSwizzleCheck1_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeConcreteTensor({2, 32});
+  fusion.addInput(tv0);
+
+  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
+  auto tv2 = add(tv1, IrBuilder::create<Double>(1));
+  auto tv3 = add(tv2, IrBuilder::create<Double>(1));
+
+  fusion.addOutput(tv3);
+
+  //[O, 4, 4]
+  tv2->split(-1, 16);
+  tv2->split(-1, 4);
+
+  //[O, 4, 4]
+  tv3->split(-1, 16);
+  tv3->split(-1, 4);
+
+  // Swizzle inner tile of tv2
+  tv2->swizzle(Swizzle2DType::ZShape, -2, -1, SwizzleMode::Loop);
+
+  // Make tv2 swizzled and half-inlined (unsupported).
+  tv0->computeAt(tv3, -2);
+
+  fusion.print();
+  FusionExecutor fe;
+  ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+}
+
 TEST_F(NVFuserTest, FusionUnsqueeze1_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
index 339ebeee19b17..c00d02c8a40dd 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
@@ -78,6 +78,80 @@ bool cudaArchGuardShouldSkip(int required_major, int required_minor) {
     COMPILE_FUSION; \
   }
 
+// Utility to track supported matmul operand layouts.
+using MatmulLayout = MmaOptions::MmaInputLayout;
+
+static constexpr std::array<MatmulLayout, 3> kAllSupportedLayout = {
+    MatmulLayout::TT,
+    MatmulLayout::NT,
+    MatmulLayout::TN};
+
+// Generic interface to get matmul op with the given layout.
+TensorView* matmul(TensorView* a, TensorView* b, MatmulLayout layout) {
+  TORCH_CHECK(
+      a->nDims() == 2 && b->nDims() == 2, "only pure matmuls for these tests");
+  TensorView *tv2 = nullptr, *tv0b = nullptr, *tv1b = nullptr;
+  switch (layout) {
+    case MatmulLayout::TT:
+      tv0b = broadcast(a, {false, false, true});
+      tv1b = broadcast(b, {true, false, false});
+      tv2 = fusedMultiplySum(tv0b, tv1b, {1});
+      break;
+    case MatmulLayout::TN:
+      tv0b = broadcast(a, {false, true, false});
+      tv1b = broadcast(b, {true, false, false});
+      tv2 = fusedMultiplySum(tv0b, tv1b, {2});
+      break;
+    case MatmulLayout::NT:
+      tv0b = broadcast(a, {false, false, true});
+      tv1b = broadcast(b, {false, true, false});
+      tv2 = fusedMultiplySum(tv0b, tv1b, {0});
+      break;
+    default:
+      TORCH_CHECK(false, "unsupported data layout.");
+  }
+  return tv2;
+}
+
+// Utility to generate matmul reference results based on given layout
+at::Tensor atMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout) {
+  switch (layout) {
+    case MatmulLayout::TT:
+      return a.matmul(b);
+    case MatmulLayout::TN:
+      return a.matmul(b.t());
+    case MatmulLayout::NT:
+      return a.t().matmul(b);
+    default:
+      TORCH_CHECK(false, "unsupported data layout.");
+  }
+  return at::Tensor();
+}
+
+// Utility to generate matmul input tensors based on given layout
+std::pair<at::Tensor, at::Tensor> fp16MatmulAtInput(
+    int M,
+    int N,
+    int K,
+    MatmulLayout layout) {
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+
+  switch (layout) {
+    case MatmulLayout::TT:
+      return std::make_pair(
+          at::randn({M, K}, options), at::randn({K, N}, options));
+    case MatmulLayout::TN:
+      return std::make_pair(
+          at::randn({M, K}, options), at::randn({N, K}, options));
+    case MatmulLayout::NT:
+      return std::make_pair(
+          at::randn({K, M}, options), at::randn({K, N}, options));
+    default:
+      TORCH_CHECK(false, "unsupported data layout.");
+  }
+  return std::make_pair(at::Tensor(), at::Tensor());
+}
+
 #define REQUIRE_DEVICE_SMEM_SIZE(required_size, device_idx)               \
   if (at::cuda::getDeviceProperties(device_idx)->sharedMemPerBlockOptin < \
       required_size) {                                                    \
@@ -315,310 +389,93 @@ TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) {
   testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
 }
 
-// Gemm test for Volta MMA: TT
-// This is the only example that is fully manual,
-// the rest of them are facilitated by gemm utils.
-TEST_F(NVFuserTest, FusionVoltaMatMulTT_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
+// Matmul test for Volta MMA: across supported layouts
+TEST_F(NVFuserTest, FusionVoltaMatmul_CUDA) {
   // Keep multiples of 8 to keep vectorizable.
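+  // (With fp16 operands, gmem loads are vectorized up to 16 bytes, i.e.
+  //  8 half elements, so keeping M/N/K multiples of 8 keeps the innermost
+  //  contiguous dimension divisible by the vector width.)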
- int M = 264, N = 120, K = 248; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,K,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {true, false, false}); - - auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); - - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TT); - - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); - - at::manual_seed(0); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({K, N}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 0, fe.compileFusion(&fusion, {t0, t1})); - auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); - - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -// Gemm test for Volta MMA: TN -TEST_F(NVFuserTest, FusionVoltaMatMulTN_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - int M = 120, N = 264, K = 56; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [N,K] - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); - - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TN); - - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); - - at::manual_seed(0); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 0, fe.compileFusion(&fusion, {t0, t1})); - auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -// Gemm test for Volta MMA: NT -TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - int M = 240, N = 320, K = 136; - - // [K,M] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [K,M,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {false, true, false}); + int M = 264, N = 136, K = 248; - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. 
- auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); + for (auto layout : kAllSupportedLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addOutput(tv2); + fusion.addInput(tv0); + fusion.addInput(tv1); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); + auto tv2 = matmul(tv0, tv1, layout); - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaInputLayout::NT); + fusion.addOutput(tv2); - mma_builder.configureMma(tv2); + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(layout); - at::manual_seed(0); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({K, M}, options); - auto t1 = at::randn({K, N}, options); + MatmulParam params(mma_builder); + params.tile_sizes = gemm_tile; + scheduleMatmul(tv2, tv0, tv1, params); - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 0, fe.compileFusion(&fusion, {t0, t1})); - auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = t0.to(at::kFloat).t().matmul(t1.to(at::kFloat)); + at::manual_seed(0); + auto inputs = fp16MatmulAtInput(M, N, K, layout); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + } } -TEST_F(NVFuserTest, FusionVoltaMatMulTTRegDoubleBuffer_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - +// Matmul test for Volta MMA: across supported layouts +TEST_F(NVFuserTest, FusionVoltaMatmulRegDoubleBuffer_CUDA) { // Keep multiples of 8 to keep vectorizable. 
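+  // double_buffer_smem_read below double buffers the smem->register
+  //  operand loads, so roughly the mma of one inner tile can overlap
+  //  with the register loads of the next.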
- int M = 264, N = 120, K = 248; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,K,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {true, false, false}); - - auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); - - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TT); - - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - params.double_buffer_options.double_buffer_smem_read = true; - scheduleMatmul(tv2, tv0, tv1, params); - - at::manual_seed(0); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({K, N}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 0, fe.compileFusion(&fusion, {t0, t1})); - auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); - - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -// Gemm test for Volta MMA: TN -TEST_F(NVFuserTest, FusionVoltaMatMulTNRegDoubleBuffer_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - int M = 120, N = 264, K = 56; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [N,K] - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); - - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TN); - - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - params.double_buffer_options.double_buffer_smem_read = true; - scheduleMatmul(tv2, tv0, tv1, params); - - at::manual_seed(0); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 0, fe.compileFusion(&fusion, {t0, t1})); - auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -// Gemm test for Volta MMA: NT -TEST_F(NVFuserTest, FusionVoltaMatMulNTRegDoubleBuffer_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - int M = 240, N = 320, K = 136; - - // [K,M] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); + int M = 264, N = 136, K = 248; - // [K,M,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {false, true, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. 
- auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); + for (auto layout : kAllSupportedLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addOutput(tv2); + fusion.addInput(tv0); + fusion.addInput(tv1); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); + auto tv2 = matmul(tv0, tv1, layout); - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaInputLayout::NT); + fusion.addOutput(tv2); - mma_builder.configureMma(tv2); + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - params.double_buffer_options.double_buffer_smem_read = true; - scheduleMatmul(tv2, tv0, tv1, params); + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(layout); - at::manual_seed(0); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({K, M}, options); - auto t1 = at::randn({K, N}, options); + MatmulParam params(mma_builder); + params.tile_sizes = gemm_tile; + params.double_buffer_options.double_buffer_smem_read = true; + scheduleMatmul(tv2, tv0, tv1, params); - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 0, fe.compileFusion(&fusion, {t0, t1})); - auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = t0.to(at::kFloat).t().matmul(t1.to(at::kFloat)); + at::manual_seed(0); + auto inputs = fp16MatmulAtInput(M, N, K, layout); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + } } // MMA unit test on Ampere @@ -848,592 +705,550 @@ TEST_F(NVFuserTest, FusionAmpereMMANT_CUDA) { testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } -// Matmul test on Ampere -TEST_F(NVFuserTest, FusionAmpereMatmulTN_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); +// Matmul test for Ampere MMA: across supported layouts +TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) { + // Keep multiples of 8 to keep vectorizable. 
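+  // This variant also turns on async_gmem_load_operands, which on Ampere
+  //  lowers the operand gmem->smem copies to cp.async, pipelined over
+  //  smem_double_buffer_stage stages.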
+ int M = 504, N = 136, K = 248; + + for (auto layout : kAllSupportedLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1, layout); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) + .layout(layout); + + MatmulParam params(mma_builder); + params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.smem_double_buffer_stage = 4; + scheduleMatmul(tv2, tv0, tv1, params); + + at::manual_seed(0); + auto inputs = fp16MatmulAtInput(M, N, K, layout); + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + } +} - Fusion fusion; - FusionGuard fg(&fusion); - int M = 511, N = 257, K = 88; +// Matmul test for Ampere MMA: with pipelined gmem load +TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) { + // Keep multiples of 8 to keep vectorizable. + int M = 504, N = 136, K = 248; + REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [N,K] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. 
- auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TN); - - // Call scheduler - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); - - FusionExecutor fe; - - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1})); - - auto cg_outputs = fe.runFusion({t0, t1}); + // Gmem pipeline stage + for (auto stage : {3, 4}) { + for (auto layout : kAllSupportedLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1, layout); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) + .layout(layout); + + MatmulParam params(mma_builder); + params.tile_sizes = gemm_tile; + params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.smem_double_buffer_stage = stage; + scheduleMatmul(tv2, tv0, tv1, params); + + at::manual_seed(0); + auto inputs = fp16MatmulAtInput(M, N, K, layout); + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + } + } +} - auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); +TEST_F(NVFuserTest, FusionAmpereMatmulRegDbouleBuffer_CUDA) { + // Keep multiples of 8 to keep vectorizable. 
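+  // Same as the pipelined test above, with register (smem read) double
+  //  buffering stacked on top of the multi-stage gmem pipeline.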
+ int M = 504, N = 136, K = 248; + REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + // Gmem pipeline stage + for (auto stage : {3, 4}) { + for (auto layout : kAllSupportedLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1, layout); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) + .layout(layout); + + MatmulParam params(mma_builder); + params.tile_sizes = gemm_tile; + params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.smem_double_buffer_stage = stage; + params.double_buffer_options.double_buffer_smem_read = true; + scheduleMatmul(tv2, tv0, tv1, params); + + at::manual_seed(0); + auto inputs = fp16MatmulAtInput(M, N, K, layout); + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + } + } } -// Matmul test on Ampere -TEST_F(NVFuserTest, FusionAmpereMatmulTT_CUDA) { +// Matmul-Matmul fusion test on Ampere +TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); - int M = 512, N = 256, K = 128; + int M = 512, N = 256, K1 = 128, K2 = 128; + + // Fusion definition (Both gemms are TN) + // [M,K1] + auto tv0 = makeConcreteTensor({M, K1}, DataType::Half); + // [K2,K1] + auto tv1 = makeConcreteTensor({K2, K1}, DataType::Half); + // [N,K2] + auto tv2 = makeConcreteTensor({N, K2}, DataType::Half); - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); fusion.addInput(tv0); fusion.addInput(tv1); + fusion.addInput(tv2); - // [M,K,N] - auto tv0b = broadcast(tv0, {false, false, true}); + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); auto tv1b = broadcast(tv1, {true, false, false}); + auto tv2b = broadcast(tv2, {true, false, false}); - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. 
- auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); - - fusion.addOutput(tv2); + // [M,K2,R] + auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + auto tv3h = castOp(DataType::Half, tv3); + auto tv3b = broadcast(tv3h, {false, true, false}); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TT); + auto tv4 = fusedMultiplySum(tv3b, tv2b, {2}); - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); + fusion.addOutput(tv4); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({K, N}, options); + // Fusion: + // Gemm(M,K2,K1) x Gemm(M,N,K2) - FusionExecutor fe; + MatMulTileOptions gemm_tile1, gemm_tile2; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1})); + // cta tile: + // To save register, n of cta tile 1 + // matches k of cta tile2 + gemm_tile1.cta_tile = GemmTile(128, 64, 32); + gemm_tile2.cta_tile = GemmTile(128, 32, 64); - auto cg_outputs = fe.runFusion({t0, t1}); + // Distribute to 2x2 warps + gemm_tile1.warp_tile = GemmTile(64, 32, 32); + gemm_tile2.warp_tile = GemmTile(64, 16, 64); - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); + // Using Ampere mma macro + gemm_tile2.instruction_tile = GemmTile(16, 8, 16); + gemm_tile2.instruction_tile = GemmTile(16, 8, 16); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} + auto mma_builder1 = + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile1) + .layout(MmaOptions::MmaInputLayout::TN); -// Matmul test on Ampere -TEST_F(NVFuserTest, FusionAmpereMatmulNT_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + auto mma_builder2 = + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile2) + .layout(MmaOptions::MmaInputLayout::TN); - Fusion fusion; - FusionGuard fg(&fusion); - int M = 512, N = 256, K = 128; + mma_builder1.configureMma(tv3); + mma_builder2.configureMma(tv4); - // [K,M] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + // Global read for gemm 1 + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); - // [K,M,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {false, true, false}); + // Global read for gemm 2 + auto tv2r = tv2->cacheAfter(); - auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); + // Gemm 1 main loop read + auto tv0cw = tv0r->cacheAfter(); + auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); + auto tv1cw = tv1r->cacheAfter(); + auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); - fusion.addOutput(tv2); + // Gemm 1 accumulator reg + auto tv3c = tv3->cacheBefore(); + mma_builder1.accumulatorTv(tv3c); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + // Gemm 2 main loop read + auto tv3cw = tv3h->cacheAfter(); + auto tv3cr = tv3cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::NT); + auto tv2cw = tv2r->cacheAfter(); + auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix); - 
MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); + // Gemm 2 accumulator reg + auto tv4c = tv4->cacheBefore(); + mma_builder2.accumulatorTv(tv4c); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({K, M}, options); - auto t1 = at::randn({K, N}, options); + // General idea is inlining gemm1's main loop inside gemm2's - FusionExecutor fe; + // Schedule gemm 2: + // ------------------------------------------------------------------ + tv4->split(-2, gemm_tile2.cta_tile.m); + tv4->split(-1, gemm_tile2.cta_tile.n); - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1})); + // 0 1 2 3 + // [Mo,M128, No, N128] + tv4->reorder({{1, 2}, {2, 1}}); - auto cg_outputs = fe.runFusion({t0, t1}); + // 0 1 2 3 + // [Mo,No, M128, N128] + tv2->computeAt(tv4, 2); + tv3->computeAt(tv4, 2); - auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv4c->split(-1, gemm_tile2.cta_tile.k); + tv4c->reorder({{2, 3}, {3, 4}, {4, 2}}); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv3->computeAt(tv4c, 3); // Implicitly defines cta tile of gemm1 + tv2r->computeAt(tv4c, 3); -// Matmul test on Ampere -TEST_F(NVFuserTest, FusionAmpereMatmulTNPipelined_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - // Require 70KB of Shared mem on device 0 to run - REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); + // Make warp tile + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction( + tv4c, gemm_tile2); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv4, gemm_tile2); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv3cr->computeAt(tv4c, -4); + tv2cr->computeAt(tv4c, -4); - Fusion fusion; - FusionGuard fg(&fusion); - int M = 511, N = 257, K = 88; + // Schedule tv2 gmem read and smem write: + // ---------------------------------------------------------------- + // [No,Ko,N,K] + tv2cw->merge(-2); + tv2r->merge(-2); - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [N,K] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv2cw, gemm_tile2, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv2r, gemm_tile2, 8); + tv2cw->setMemoryType(MemoryType::Shared); - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); + // Schedule tv2 gmem read and smem write: + // ---------------------------------------------------------------- - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + // Schedule gemm 2 mma input + // --------------------------------------------------------------------------- + tv3cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); - fusion.addOutput(tv2); + // [... 
Mi, Ni, Ki] want [Ni, Mi, Ki] + tv3b->reorder({{-2, -3}, {-3, -2}}); + tv3b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + tv2cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); + tv2b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TN); + // Schedule mma output + // --------------------------------------------------------------------------- + tv4c->applyMmaSwizzle( + mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); + tv4->applyMmaSwizzle( + mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); - // Call scheduler - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - params.async_gmem_load_operands = true; - params.double_buffer_options.double_buffer_smem_write = true; - params.double_buffer_options.smem_double_buffer_stage = 4; - scheduleMatmul(tv2, tv0, tv1, params); + // Schedule gemm 1: + // ------------------------------------------------------------------ - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); + // CTA tile: + tv0->computeAt(tv3, 2); + tv1->computeAt(tv3, 2); - FusionExecutor fe; + // Schedule K dim for gemm 1: - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1})); + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv3c->split(-1, gemm_tile1.cta_tile.k); + tv3c->reorder({{2, 3}, {3, 4}, {4, 2}}); + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv3c, 3); + tv1r->computeAt(tv3c, 3); - auto cg_outputs = fe.runFusion({t0, t1}); + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction( + tv3c, gemm_tile1); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv3cw, gemm_tile1); - auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + tv0cr->computeAt(tv3c, -4); + tv1cr->computeAt(tv3c, -4); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} + tv3->computeAt(tv3cw, -3); -// Matmul test on Ampere -TEST_F(NVFuserTest, FusionAmpereMatmulTTPipelined_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - // Require 70KB of Shared mem on device 0 to run - REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile1, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile1, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] - Fusion fusion; - FusionGuard fg(&fusion); - int M = 512, N = 256, K = 120; + // [No,Ko,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile1, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile1, 8); + tv1cw->setMemoryType(MemoryType::Shared); - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, 
DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); + // [... Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); - // [M,K,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {true, false, false}); + tv1cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); + // Schedule mma output + // --------------------------------------------------------------------------- + tv3c->applyMmaSwizzle( + mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); + tv3cw->applyMmaSwizzle( + mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); + tv3h->applyMmaSwizzle( + mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); + tv3->applyMmaSwizzle( + mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); + tv3cw->setMemoryType(MemoryType::Shared); - fusion.addOutput(tv2); + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + // Gemm 1 + tv3c->axis(3)->parallelize(ParallelType::TIDz); + tv3c->axis(4)->parallelize(ParallelType::TIDy); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + tv3->computeAt(tv3cw, -2); + tv3cw->axis(2)->parallelize(ParallelType::TIDz); + tv3cw->axis(3)->parallelize(ParallelType::TIDy); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TT); + // Gemm 2 + tv4->axis(2)->parallelize(ParallelType::TIDz); + tv4->axis(3)->parallelize(ParallelType::TIDy); + tv4c->axis(3)->parallelize(ParallelType::TIDz); + tv4c->axis(4)->parallelize(ParallelType::TIDy); - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - params.async_gmem_load_operands = true; - params.double_buffer_options.double_buffer_smem_write = true; - params.double_buffer_options.smem_double_buffer_stage = 3; - scheduleMatmul(tv2, tv0, tv1, params); + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::BIDy); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({K, N}, options); + auto t0 = at::randn({M, K1}, options); + auto t1 = at::randn({K2, K1}, options); + auto t2 = at::randn({N, K2}, options); + + auto tref = t0.to(at::kFloat) + .matmul(t1.t().to(at::kFloat)) + .matmul(t2.t().to(at::kFloat)); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1})); - - auto cg_outputs = fe.runFusion({t0, t1}); + 8, 0, fe.compileFusion(&fusion, {t0, t1, t2})); - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); + auto cg_outputs = fe.runFusion({t0, t1, t2}); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + // relaxed check for now, err accumulation is significant. 
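+  // Two chained fp16 gemms (K1 = K2 = 128) accumulate noticeably more
+  //  rounding error than the single-gemm tests, which compare at 1e-4.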
+  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1));
 }
 
-// Matmul test on Ampere
-TEST_F(NVFuserTest, FusionAmpereMatmulNTPipelined_CUDA) {
+// Simplified Matmul-Softmax-Matmul test on Ampere
+// (To be extended in follow-ups)
+TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) {
   NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
-  // Require 70KB of Shared mem on device 0 to run
-  REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0);
 
   Fusion fusion;
   FusionGuard fg(&fusion);
-  int M = 512, N = 256, K = 136;
-
-  // [K,M]
-  auto tv0 = makeContigTensor(2, DataType::Half);
-  // [K,N]
-  auto tv1 = makeContigTensor(2, DataType::Half);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  // [K,M,N]
-  auto tv0b = broadcast(tv0, {false, false, true});
-  auto tv1b = broadcast(tv1, {false, true, false});
-
-  auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});
-
-  fusion.addOutput(tv2);
-
-  MatMulTileOptions gemm_tile;
-  gemm_tile.cta_tile = GemmTile(128, 128, 32);
-  gemm_tile.warp_tile = GemmTile(64, 64, 32);
-  gemm_tile.instruction_tile = GemmTile(16, 8, 16);
-  auto mma_builder =
-      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
-          .layout(MmaOptions::MmaInputLayout::NT);
+  // Omitting outer dimensions and pointwise ops
-  MatmulParam params(mma_builder);
-  params.tile_sizes = gemm_tile;
-  params.async_gmem_load_operands = true;
-  params.double_buffer_options.double_buffer_smem_write = true;
-  params.double_buffer_options.smem_double_buffer_stage = 3;
-  scheduleMatmul(tv2, tv0, tv1, params);
+  const int seql_q = 32;
+  const int seql_k = 128;
+  const int hidden_size = 1024;
+  const int num_heads = 16;
+  const int head_dim = hidden_size / num_heads;
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-  auto t0 = at::randn({K, M}, options);
-  auto t1 = at::randn({K, N}, options);
+  // Gemm 1:
+  // (M,N,K) = (32, 128, 64)
+  const int M1 = seql_q, N1 = seql_k, K1 = head_dim;
+  // Gemm 2:
+  // (M,N,K) = (32, 64, 128)
+  const int M2 = seql_q, N2 = head_dim, K2 = seql_k;
-  FusionExecutor fe;
+  // Fusion definition (both gemms are TN)
+  // [M,K1]
+  auto inp = makeConcreteTensor({M1, K1}, DataType::Half);
+  // Key matrix
+  auto qk = makeConcreteTensor({N1, K1}, DataType::Half);
+  // Second linear (value) matrix
+  auto acc = makeConcreteTensor({N2, K2}, DataType::Half);
-  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
-      8, 0, fe.compileFusion(&fusion, {t0, t1}));
+  fusion.addInput(inp);
+  fusion.addInput(qk);
+  fusion.addInput(acc);
-  auto cg_outputs = fe.runFusion({t0, t1});
+  // [M,N,K]
+  auto tv0b = broadcast(inp, {false, true, false});
+  auto tv1b = broadcast(qk, {true, false, false});
+  auto tv2b = broadcast(acc, {true, false, false});
-  auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat));
+  // [M,K2,R]
+  auto tv3 = fusedMultiplySum(tv0b, tv1b, {2});
-  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
-}
+  // Define softmax inline for now so it can be scheduled manually
+  auto x = tv3;
+  const int kReductionAxis = 1;
+  const int kNumberOfDims = 2;
+  std::vector<bool> broadcast_mask(kNumberOfDims, false);
+  broadcast_mask[kReductionAxis] = true;
-
-// Matmul test on Ampere
-TEST_F(NVFuserTest, FusionAmpereMatmulNTRegDoubleBuffer_CUDA) {
-  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
+  auto max_val = max(x, {kReductionAxis});
+  auto bcast_max = broadcast(max_val, broadcast_mask);
+  auto x_max_sub = sub(x, bcast_max);
+  auto exp_val = exp(x_max_sub);
+  auto sum_exp = sum(exp_val, {kReductionAxis});
+  auto bcast_sum = broadcast(sum_exp, broadcast_mask);
+  auto recip = reciprocal(bcast_sum);
+  auto tv3sfm = mul(exp_val, recip);
-  Fusion fusion;
-  FusionGuard 
fg(&fusion); - int M = 512, N = 256, K = 128; + auto tv3h = castOp(DataType::Half, tv3sfm); + auto tv3b = broadcast(tv3h, {false, true, false}); + auto tv4 = fusedMultiplySum(tv3b, tv2b, {2}); - // [K,M] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + fusion.addOutput(tv4); - // [K,M,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {false, true, false}); + // Fusion: + // Gemm(M,K2,K1) x Gemm(M,N,K2) + MatMulTileOptions gemm_tile; - auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); + // TODO: use very small tiles for now since + // alias pass is not re-using smem. Fix later. + gemm_tile.cta_tile = GemmTile(32, 128, 32); - fusion.addOutput(tv2); + // Distribute to 2x2 warps + gemm_tile.warp_tile = GemmTile(16, 64, 32); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); + // Using Ampere mma macro gemm_tile.instruction_tile = GemmTile(16, 8, 16); - auto mma_builder = + auto mma_builder1 = MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::NT); + .layout(MmaOptions::MmaInputLayout::TN); - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - params.double_buffer_options.double_buffer_smem_read = true; - scheduleMatmul(tv2, tv0, tv1, params); + auto mma_builder2 = + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({K, M}, options); - auto t1 = at::randn({K, N}, options); + mma_builder1.configureMma(tv3); + mma_builder2.configureMma(tv4); - FusionExecutor fe; + // Global read for gemm 1 + auto tv0r = inp->cacheAfter(); + auto tv1r = qk->cacheAfter(); - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1})); + // Global read for gemm 2 + auto tv2r = acc->cacheAfter(); - auto cg_outputs = fe.runFusion({t0, t1}); + // Gemm 1 main loop read + auto tv0cw = tv0r->cacheAfter(); + auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); + auto tv1cw = tv1r->cacheAfter(); + auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); + // Gemm 1 accumulator reg + auto tv3c = tv3->cacheBefore(); + mma_builder1.accumulatorTv(tv3c); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -// Matmul test on Ampere -TEST_F(NVFuserTest, FusionAmpereMatmulTNRegDoubleBuffer_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - int M = 511, N = 257, K = 88; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [N,K] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. 
- auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TN); - - // Call scheduler - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - params.double_buffer_options.double_buffer_smem_read = true; - scheduleMatmul(tv2, tv0, tv1, params); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); - - FusionExecutor fe; - - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1})); - - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -// Matmul test on Ampere -TEST_F(NVFuserTest, FusionAmpereMatmulTTRegDoubleBuffer_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - int M = 512, N = 256, K = 128; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,K,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TT); - - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - params.double_buffer_options.double_buffer_smem_read = true; - scheduleMatmul(tv2, tv0, tv1, params); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({K, N}, options); - - FusionExecutor fe; - - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1})); - - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); - - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -// Matmul-Matmul fusion test on Ampere -TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - int M = 512, N = 256, K1 = 128, K2 = 128; - - // Fusion definition (Both gemms are TN) - // [M,K1] - auto tv0 = makeConcreteTensor({M, K1}, DataType::Half); - // [K2,K1] - auto tv1 = makeConcreteTensor({K2, K1}, DataType::Half); - // [N,K2] - auto tv2 = makeConcreteTensor({N, K2}, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - auto tv2b = broadcast(tv2, {true, false, false}); - - // [M,K2,R] - auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); - - auto tv3h = castOp(DataType::Half, tv3); - auto tv3b = broadcast(tv3h, {false, true, false}); - - auto tv4 = fusedMultiplySum(tv3b, 
tv2b, {2}); - - fusion.addOutput(tv4); - - // Fusion: - // Gemm(M,K2,K1) x Gemm(M,N,K2) - - MatMulTileOptions gemm_tile1, gemm_tile2; - - // cta tile: - // To save register, n of cta tile 1 - // matches k of cta tile2 - gemm_tile1.cta_tile = GemmTile(128, 64, 32); - gemm_tile2.cta_tile = GemmTile(128, 32, 64); - - // Distribute to 2x2 warps - gemm_tile1.warp_tile = GemmTile(64, 32, 32); - gemm_tile2.warp_tile = GemmTile(64, 16, 64); - - // Using Ampere mma macro - gemm_tile2.instruction_tile = GemmTile(16, 8, 16); - gemm_tile2.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder1 = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile1) - .layout(MmaOptions::MmaInputLayout::TN); - - auto mma_builder2 = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile2) - .layout(MmaOptions::MmaInputLayout::TN); - - mma_builder1.configureMma(tv3); - mma_builder2.configureMma(tv4); - - // Global read for gemm 1 - auto tv0r = tv0->cacheAfter(); - auto tv1r = tv1->cacheAfter(); - - // Global read for gemm 2 - auto tv2r = tv2->cacheAfter(); - - // Gemm 1 main loop read - auto tv0cw = tv0r->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tv1cw = tv1r->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); + // Softmax conversion: + auto tv3ccr = tv3->cacheAfter(); - // Gemm 1 accumulator reg - auto tv3c = tv3->cacheBefore(); - mma_builder1.accumulatorTv(tv3c); + // tv3ccr -> tv3h : softmax // Gemm 2 main loop read - auto tv3cw = tv3h->cacheAfter(); - auto tv3cr = tv3cw->cacheAfter(LoadStoreOpType::LdMatrix); + // auto tv3cw = tv3h->cacheAfter(); + auto tv3cr = tv3h->cacheAfter(LoadStoreOpType::LdMatrix); auto tv2cw = tv2r->cacheAfter(); auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix); @@ -1442,12 +1257,10 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { auto tv4c = tv4->cacheBefore(); mma_builder2.accumulatorTv(tv4c); - // General idea is inlining gemm1's main loop inside gemm2's - // Schedule gemm 2: // ------------------------------------------------------------------ - tv4->split(-2, gemm_tile2.cta_tile.m); - tv4->split(-1, gemm_tile2.cta_tile.n); + tv4->split(-2, gemm_tile.cta_tile.m); + tv4->split(-1, gemm_tile.cta_tile.n); // 0 1 2 3 // [Mo,M128, No, N128] @@ -1455,25 +1268,24 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { // 0 1 2 3 // [Mo,No, M128, N128] - tv2->computeAt(tv4, 2); + acc->computeAt(tv4, 2); tv3->computeAt(tv4, 2); // Order K // 0 1 2 3 4 5 // [Mo,No, M128, N128, Ko, K32] - tv4c->split(-1, gemm_tile2.cta_tile.k); + tv4c->split(-1, gemm_tile.cta_tile.k); tv4c->reorder({{2, 3}, {3, 4}, {4, 2}}); // 0 1 2 3 4 5 // [Mo,No, Ko M128, N128, K32] - tv3->computeAt(tv4c, 3); // Implicitly defines cta tile of gemm1 + tv3->computeAt(tv4c, 2); tv2r->computeAt(tv4c, 3); // Make warp tile - scheduler_utils::matmul_utils::scheduleWarpTileWithReduction( - tv4c, gemm_tile2); + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv4c, gemm_tile); scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( - tv4, gemm_tile2); + tv4, gemm_tile); // -8 -7 -6 -5 -4 -3 -2 -1 // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] tv3cr->computeAt(tv4c, -4); @@ -1487,9 +1299,9 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { // [No,Ko,i,wy,wx,v] scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv2cw, gemm_tile2, 8); + tv2cw, gemm_tile, 8); scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv2r, gemm_tile2, 8); + tv2r, gemm_tile, 8); tv2cw->setMemoryType(MemoryType::Shared); // 
Schedule tv2 gmem read and smem write: @@ -1498,7 +1310,6 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { // Schedule gemm 2 mma input // --------------------------------------------------------------------------- tv3cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); - // [... Mi, Ni, Ki] want [Ni, Mi, Ki] tv3b->reorder({{-2, -3}, {-3, -2}}); tv3b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); @@ -1517,15 +1328,23 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { // ------------------------------------------------------------------ // CTA tile: - tv0->computeAt(tv3, 2); - tv1->computeAt(tv3, 2); + // [Mo, Mi128, N80] + + tv3->split(-1, gemm_tile.cta_tile.n); + // [Mo, Mi128, No, Ni128] + + tv3->reorder({{1, 2}, {2, 1}}); + + // [Mo, No, Mi128, Ni128] + inp->computeAt(tv3, 2); + qk->computeAt(tv3, 2); // Schedule K dim for gemm 1: // Order K // 0 1 2 3 4 5 // [Mo,No, M128, N128, Ko, K32] - tv3c->split(-1, gemm_tile1.cta_tile.k); + tv3c->split(-1, gemm_tile.cta_tile.k); tv3c->reorder({{2, 3}, {3, 4}, {4, 2}}); // 0 1 2 3 4 5 // [Mo,No, Ko M128, N128, K32] @@ -1534,15 +1353,14 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { // Make warp tile: // ------------------------------------------------------------------------- - scheduler_utils::matmul_utils::scheduleWarpTileWithReduction( - tv3c, gemm_tile1); + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv3c, gemm_tile); scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( - tv3cw, gemm_tile1); + tv3, gemm_tile); tv0cr->computeAt(tv3c, -4); tv1cr->computeAt(tv3c, -4); - tv3->computeAt(tv3cw, -3); + // tv3->computeAt(tv3cw,-3); // Schedule gmem read and smem write: // --------------------------------------------------------------------------- @@ -1550,9 +1368,9 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { tv0cw->merge(-2); tv0r->merge(-2); scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv0cw, gemm_tile1, 8); + tv0cw, gemm_tile, 8); scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv0r, gemm_tile1, 8); + tv0r, gemm_tile, 8); tv0cw->setMemoryType(MemoryType::Shared); // [Mo,Ko,i,wy,wx,v] @@ -1561,9 +1379,9 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { tv1r->merge(-2); // [No,Ko,i,wy,wx,v] scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv1cw, gemm_tile1, 8); + tv1cw, gemm_tile, 8); scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv1r, gemm_tile1, 8); + tv1r, gemm_tile, 8); tv1cw->setMemoryType(MemoryType::Shared); // Schedule mma input @@ -1576,827 +1394,139 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { tv1cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); tv1b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); - // Schedule mma output + // // Schedule mma output + // // // --------------------------------------------------------------------------- tv3c->applyMmaSwizzle( mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); - tv3cw->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); - tv3h->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); tv3->applyMmaSwizzle( mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); - tv3cw->setMemoryType(MemoryType::Shared); - // Parallelize - // 0 1 2 3 4 5 6 7 - // [Mo No Mwo Nwo Mw Nw (Mi Ni)] - // Gemm 1 - tv3c->axis(3)->parallelize(ParallelType::TIDz); - 
tv3c->axis(4)->parallelize(ParallelType::TIDy); + // mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(tv3ccw, + // mma_builder1.build()); - tv3->computeAt(tv3cw, -2); - tv3cw->axis(2)->parallelize(ParallelType::TIDz); - tv3cw->axis(3)->parallelize(ParallelType::TIDy); + // Put tv3 result in smem + tv3->setMemoryType(MemoryType::Shared); - // Gemm 2 - tv4->axis(2)->parallelize(ParallelType::TIDz); - tv4->axis(3)->parallelize(ParallelType::TIDy); - tv4c->axis(3)->parallelize(ParallelType::TIDz); - tv4c->axis(4)->parallelize(ParallelType::TIDy); - - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(1)->parallelize(ParallelType::BIDy); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K1}, options); - auto t1 = at::randn({K2, K1}, options); - auto t2 = at::randn({N, K2}, options); - - auto tref = t0.to(at::kFloat) - .matmul(t1.t().to(at::kFloat)) - .matmul(t2.t().to(at::kFloat)); - - FusionExecutor fe; - - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1, t2})); - - auto cg_outputs = fe.runFusion({t0, t1, t2}); - - // relaxed check for now, err accumulation is significant. - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1)); -} - -// Simplified Matmul-Softmax-Matmul test on Ampere -// (To be extended in follow ups) -TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - - // Omitting outer dimensions and pointwise ops - - const int seql_q = 32; - const int seql_k = 128; - const int hidden_size = 1024; - const int num_heads = 16; - const int head_dim = hidden_size / num_heads; - - // Gemm 1: - // (80, 80, 64) - const int M1 = seql_q, N1 = seql_k, K1 = head_dim; - // (80, 64, 80) - const int M2 = seql_q, N2 = head_dim, K2 = seql_k; - - // Fusion definition (Both gemms are TN) - // [M,K1] - auto inp = makeConcreteTensor({M1, K1}, DataType::Half); - // Query matrix - auto qk = makeConcreteTensor({N1, K1}, DataType::Half); - // Second linear matrix - auto acc = makeConcreteTensor({N2, K2}, DataType::Half); - - fusion.addInput(inp); - fusion.addInput(qk); - fusion.addInput(acc); - - // [M,N,K] - auto tv0b = broadcast(inp, {false, true, false}); - auto tv1b = broadcast(qk, {true, false, false}); - auto tv2b = broadcast(acc, {true, false, false}); - - // [M,K2,R] - auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); - - // Inline define softmax for now for scheduling - auto x = tv3; - const int kReductionAxis = 1; - const int kNumberOfDims = 2; - std::vector broadcast_mask(kNumberOfDims, false); - broadcast_mask[kReductionAxis] = true; - - auto max_val = max(x, {kReductionAxis}); - auto bcast_max = broadcast(max_val, broadcast_mask); - auto x_max_sub = sub(x, bcast_max); - auto exp_val = exp(x_max_sub); - auto sum_exp = sum(exp_val, {kReductionAxis}); - auto bcast_sum = broadcast(sum_exp, broadcast_mask); - auto recip = reciprocal(bcast_sum); - auto tv3sfm = mul(exp_val, recip); - - auto tv3h = castOp(DataType::Half, tv3sfm); - auto tv3b = broadcast(tv3h, {false, true, false}); - auto tv4 = fusedMultiplySum(tv3b, tv2b, {2}); - - fusion.addOutput(tv4); - - // Fusion: - // Gemm(M,K2,K1) x Gemm(M,N,K2) - MatMulTileOptions gemm_tile; - - // TODO: use very small tiles for now since - // alias pass is not re-using smem. Fix later. 
- gemm_tile.cta_tile = GemmTile(32, 128, 32); - - // Distribute to 2x2 warps - gemm_tile.warp_tile = GemmTile(16, 64, 32); - - // Using Ampere mma macro - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder1 = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TN); - - auto mma_builder2 = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TN); - - mma_builder1.configureMma(tv3); - mma_builder2.configureMma(tv4); - - // Global read for gemm 1 - auto tv0r = inp->cacheAfter(); - auto tv1r = qk->cacheAfter(); - - // Global read for gemm 2 - auto tv2r = acc->cacheAfter(); - - // Gemm 1 main loop read - auto tv0cw = tv0r->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tv1cw = tv1r->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); - - // Gemm 1 accumulator reg - auto tv3c = tv3->cacheBefore(); - mma_builder1.accumulatorTv(tv3c); - - // Softmax conversion: - auto tv3ccr = tv3->cacheAfter(); - - // tv3ccr -> tv3h : softmax - - // Gemm 2 main loop read - // auto tv3cw = tv3h->cacheAfter(); - auto tv3cr = tv3h->cacheAfter(LoadStoreOpType::LdMatrix); - - auto tv2cw = tv2r->cacheAfter(); - auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix); - - // Gemm 2 accumulator reg - auto tv4c = tv4->cacheBefore(); - mma_builder2.accumulatorTv(tv4c); - - // Schedule gemm 2: - // ------------------------------------------------------------------ - tv4->split(-2, gemm_tile.cta_tile.m); - tv4->split(-1, gemm_tile.cta_tile.n); - - // 0 1 2 3 - // [Mo,M128, No, N128] - tv4->reorder({{1, 2}, {2, 1}}); - - // 0 1 2 3 - // [Mo,No, M128, N128] - acc->computeAt(tv4, 2); - tv3->computeAt(tv4, 2); - - // Order K - // 0 1 2 3 4 5 - // [Mo,No, M128, N128, Ko, K32] - tv4c->split(-1, gemm_tile.cta_tile.k); - tv4c->reorder({{2, 3}, {3, 4}, {4, 2}}); - - // 0 1 2 3 4 5 - // [Mo,No, Ko M128, N128, K32] - tv3->computeAt(tv4c, 2); - tv2r->computeAt(tv4c, 3); - - // Make warp tile - scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv4c, gemm_tile); - scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( - tv4, gemm_tile); - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] - tv3cr->computeAt(tv4c, -4); - tv2cr->computeAt(tv4c, -4); - - // Schedule tv2 gmem read and smem write: - // ---------------------------------------------------------------- - // [No,Ko,N,K] - tv2cw->merge(-2); - tv2r->merge(-2); - - // [No,Ko,i,wy,wx,v] - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv2cw, gemm_tile, 8); - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv2r, gemm_tile, 8); - tv2cw->setMemoryType(MemoryType::Shared); - - // Schedule tv2 gmem read and smem write: - // ---------------------------------------------------------------- - - // Schedule gemm 2 mma input - // --------------------------------------------------------------------------- - tv3cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); - // [... 
Mi, Ni, Ki] want [Ni, Mi, Ki] - tv3b->reorder({{-2, -3}, {-3, -2}}); - tv3b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); - - tv2cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); - tv2b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); - - // Schedule mma output - // --------------------------------------------------------------------------- - tv4c->applyMmaSwizzle( - mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); - tv4->applyMmaSwizzle( - mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); - - // Schedule gemm 1: - // ------------------------------------------------------------------ - - // CTA tile: - // [Mo, Mi128, N80] - - tv3->split(-1, gemm_tile.cta_tile.n); - // [Mo, Mi128, No, Ni128] - - tv3->reorder({{1, 2}, {2, 1}}); - - // [Mo, No, Mi128, Ni128] - inp->computeAt(tv3, 2); - qk->computeAt(tv3, 2); - - // Schedule K dim for gemm 1: - - // Order K - // 0 1 2 3 4 5 - // [Mo,No, M128, N128, Ko, K32] - tv3c->split(-1, gemm_tile.cta_tile.k); - tv3c->reorder({{2, 3}, {3, 4}, {4, 2}}); - // 0 1 2 3 4 5 - // [Mo,No, Ko M128, N128, K32] - tv0r->computeAt(tv3c, 3); - tv1r->computeAt(tv3c, 3); - - // Make warp tile: - // ------------------------------------------------------------------------- - scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv3c, gemm_tile); - scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( - tv3, gemm_tile); - - tv0cr->computeAt(tv3c, -4); - tv1cr->computeAt(tv3c, -4); - - // tv3->computeAt(tv3cw,-3); - - // Schedule gmem read and smem write: - // --------------------------------------------------------------------------- - // [Mo,Ko,M,K] - tv0cw->merge(-2); - tv0r->merge(-2); - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv0cw, gemm_tile, 8); - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv0r, gemm_tile, 8); - tv0cw->setMemoryType(MemoryType::Shared); - // [Mo,Ko,i,wy,wx,v] - - // [No,Ko,N,K] - tv1cw->merge(-2); - tv1r->merge(-2); - // [No,Ko,i,wy,wx,v] - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv1cw, gemm_tile, 8); - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv1r, gemm_tile, 8); - tv1cw->setMemoryType(MemoryType::Shared); - - // Schedule mma input - // --------------------------------------------------------------------------- - tv0cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); - // [... 
Mi, Ni, Ki] want [Ni, Mi, Ki] - tv0b->reorder({{-2, -3}, {-3, -2}}); - tv0b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); - - tv1cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); - tv1b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); - - // // Schedule mma output - // // - // --------------------------------------------------------------------------- - tv3c->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); - tv3->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); - - // mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(tv3ccw, - // mma_builder1.build()); - - // Put tv3 result in smem - tv3->setMemoryType(MemoryType::Shared); - - // schedule a reg persistent softmax: from tv3 - // [Mo, M128, RN] - max_val->split(-1, 128); - // [Mo, M128, RN1, RN128] - max_val->split(-1, 4); - // Map to warp (2x2) - max_val->split(-4, 4); - max_val->split(-4, 2); - - // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] - auto max_rf = max_val->rFactor({-1}); - // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] - - // [Mo, M128, RN] - sum_exp->split(-1, 128); - // [Mo, M128, RN1, RN128] - sum_exp->split(-1, 4); - // Map to warp (2x2) - sum_exp->split(-4, 4); - sum_exp->split(-4, 2); - - // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] - auto sum_exp_rf = sum_exp->rFactor({-1}); - // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] - - exp_val->computeAt(sum_exp_rf, 4); - exp_val->split(-1, 128); - exp_val->split(-1, 4); - bcast_max->computeAt(exp_val, -2); - - // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] - - // Read from smem - tv3ccr->computeAt(max_rf, 4); - // [Mo, Mo32, My2, Mx2, N80] - tv3ccr->split(-1, 128); - tv3ccr->split(-1, 4); - // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] - - // Write to second gemm - tv3h->split(-1, 128); - tv3h->split(-1, 4); - // Map to warp (2x2) - tv3h->split(-4, 4); - tv3h->split(-4, 2); - - bcast_sum->computeAt(tv3h, -2); - - tv3h->setMemoryType(MemoryType::Shared); - - // Parallelize - tv4->axis(0)->parallelize(ParallelType::BIDx); - // 0 1 2 3 4 5 6 7 - // [Mo No Mwo Nwo Mw Nw (Mi Ni)] - // Gemm 1 - tv3c->axis(3)->parallelize(ParallelType::TIDz); - tv3c->axis(4)->parallelize(ParallelType::TIDy); - tv3->axis(2)->parallelize(ParallelType::TIDz); - tv3->axis(3)->parallelize(ParallelType::TIDy); - - auto parallelize_non_reduced_val = [](TensorView* tv) { - tv->axis(-2)->parallelize(ParallelType::TIDx); - tv->axis(2)->parallelize(ParallelType::TIDz); - tv->axis(3)->parallelize(ParallelType::TIDy); - }; - - auto parallelize_reduced_val = [](TensorView* tv) { - tv->axis(-1)->parallelize(ParallelType::TIDx); - tv->axis(2)->parallelize(ParallelType::TIDz); - tv->axis(3)->parallelize(ParallelType::TIDy); - }; - - parallelize_non_reduced_val(tv3h); - parallelize_non_reduced_val(max_rf); - parallelize_non_reduced_val(bcast_max); - parallelize_non_reduced_val(exp_val); - parallelize_non_reduced_val(sum_exp_rf); - parallelize_non_reduced_val(bcast_sum); - parallelize_non_reduced_val(recip); - - parallelize_reduced_val(max_val); - parallelize_reduced_val(sum_exp); - - // 0 1 2 3 4 5 6 7 - // [Mo No Mwo Nwo Mw Nw (Mi Ni)] - // Gemm 2 - tv4->axis(2)->parallelize(ParallelType::TIDz); - tv4->axis(3)->parallelize(ParallelType::TIDy); - tv4c->axis(3)->parallelize(ParallelType::TIDz); - tv4c->axis(4)->parallelize(ParallelType::TIDy); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M1, K1}, options); - auto t1 = at::randn({N1, K1}, options); - auto t2 = 
at::randn({N2, K2}, options); - - FusionExecutor fe; - - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1, t2})); - - auto cg_outputs = fe.runFusion({t0, t1, t2}); - - auto g1 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - auto sg1 = at::_softmax(g1, -1, false); - auto gsg1 = sg1.matmul(t2.t().to(at::kFloat)); - - TORCH_CHECK(cg_outputs[0].allclose(gsg1, 0.001, 0.001)); -} - -// MMA unit test on Turing -TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // [M,K] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [N,K] - auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TN); - - mma_builder.configureMma(tv2); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = - tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = - tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); - - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); - - // [M,N,K] -> [N,M,K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({8, 16}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 5, fe.compileFusion(&fusion, {t0, t1})); - - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} - -// MMA unit test on Turing -TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // [M,K] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [K,N] - auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,K,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {true, false, false}); - - auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TT); - - mma_builder.configureMma(tv2); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = - 
tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = - tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); - - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); - - // [M,K,N] -> [N,M,K] - tv0cr->reorder({{-3, -2}, {-2, -1}, {-1, -3}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - - // [M,K,N] -> [M,N,K] - tv1cr->reorder({{-2, -1}, {-1, -2}}); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - - // [M,K,N] -> [M,N,K] - tv2c->reorder({{-2, -1}, {-1, -2}}); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({16, 8}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 5, fe.compileFusion(&fusion, {t0, t1})); - - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); - - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} - -// MMA unit test on Turing -TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // [K,M] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [K,N] - auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [K,M,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {false, true, false}); - auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::NT); - - mma_builder.configureMma(tv2); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = - tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = - tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); - - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); - - // [K,M,N] -> [N,M,K] - tv0cr->reorder({{-3, -1}, {-1, -3}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - - // [K,M,N] -> [M,N,K] - tv1cr->reorder({ - {-3, -1}, - {-2, -3}, - {-1, -2}, - }); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - - // [K,M,N] -> [M,N,K] - tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}}); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({16, 8}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 5, fe.compileFusion(&fusion, {t0, t1})); - - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = 
t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); - - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} - -// Matmul test on Turing -TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - int M = 511, N = 257, K = 88; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [N,K] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TN); - - mma_builder.configureMma(tv2); - - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 5, fe.compileFusion(&fusion, {t0, t1})); - - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -// Matmul test on Turing -TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - int M = 512, N = 256, K = 128; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,K,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. 
- auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TT); + // schedule a reg persistent softmax: from tv3 + // [Mo, M128, RN] + max_val->split(-1, 128); + // [Mo, M128, RN1, RN128] + max_val->split(-1, 4); + // Map to warp (2x2) + max_val->split(-4, 4); + max_val->split(-4, 2); - mma_builder.configureMma(tv2); + // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] + auto max_rf = max_val->rFactor({-1}); + // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); + // [Mo, M128, RN] + sum_exp->split(-1, 128); + // [Mo, M128, RN1, RN128] + sum_exp->split(-1, 4); + // Map to warp (2x2) + sum_exp->split(-4, 4); + sum_exp->split(-4, 2); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({K, N}, options); + // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] + auto sum_exp_rf = sum_exp->rFactor({-1}); + // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 5, fe.compileFusion(&fusion, {t0, t1})); + exp_val->computeAt(sum_exp_rf, 4); + exp_val->split(-1, 128); + exp_val->split(-1, 4); + bcast_max->computeAt(exp_val, -2); - auto cg_outputs = fe.runFusion({t0, t1}); + // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); + // Read from smem + tv3ccr->computeAt(max_rf, 4); + // [Mo, Mo32, My2, Mx2, N80] + tv3ccr->split(-1, 128); + tv3ccr->split(-1, 4); + // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} + // Write to second gemm + tv3h->split(-1, 128); + tv3h->split(-1, 4); + // Map to warp (2x2) + tv3h->split(-4, 4); + tv3h->split(-4, 2); -// Matmul test on Turing -TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - int M = 512, N = 256, K = 128; + bcast_sum->computeAt(tv3h, -2); - // [K,M] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + tv3h->setMemoryType(MemoryType::Shared); - // [K,M,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {false, true, false}); + // Parallelize + tv4->axis(0)->parallelize(ParallelType::BIDx); + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + // Gemm 1 + tv3c->axis(3)->parallelize(ParallelType::TIDz); + tv3c->axis(4)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDz); + tv3->axis(3)->parallelize(ParallelType::TIDy); - auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); + auto parallelize_non_reduced_val = [](TensorView* tv) { + tv->axis(-2)->parallelize(ParallelType::TIDx); + tv->axis(2)->parallelize(ParallelType::TIDz); + tv->axis(3)->parallelize(ParallelType::TIDy); + }; - fusion.addOutput(tv2); + auto parallelize_reduced_val = [](TensorView* tv) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + tv->axis(2)->parallelize(ParallelType::TIDz); + tv->axis(3)->parallelize(ParallelType::TIDy); + }; - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - 
gemm_tile.instruction_tile = GemmTile(16, 8, 16); + parallelize_non_reduced_val(tv3h); + parallelize_non_reduced_val(max_rf); + parallelize_non_reduced_val(bcast_max); + parallelize_non_reduced_val(exp_val); + parallelize_non_reduced_val(sum_exp_rf); + parallelize_non_reduced_val(bcast_sum); + parallelize_non_reduced_val(recip); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::NT); + parallelize_reduced_val(max_val); + parallelize_reduced_val(sum_exp); - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + // Gemm 2 + tv4->axis(2)->parallelize(ParallelType::TIDz); + tv4->axis(3)->parallelize(ParallelType::TIDy); + tv4c->axis(3)->parallelize(ParallelType::TIDz); + tv4c->axis(4)->parallelize(ParallelType::TIDy); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({K, M}, options); - auto t1 = at::randn({K, N}, options); + auto t0 = at::randn({M1, K1}, options); + auto t1 = at::randn({N1, K1}, options); + auto t2 = at::randn({N2, K2}, options); FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 5, fe.compileFusion(&fusion, {t0, t1})); + 8, 0, fe.compileFusion(&fusion, {t0, t1, t2})); - auto cg_outputs = fe.runFusion({t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1, t2}); - auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); + auto g1 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + auto sg1 = at::_softmax(g1, -1, false); + auto gsg1 = sg1.matmul(t2.t().to(at::kFloat)); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + TORCH_CHECK(cg_outputs[0].allclose(gsg1, 0.001, 0.001)); } -// Matmul test on Turing -TEST_F(NVFuserTest, FusionTuringMatmulTNRegDoubleBuffer_CUDA) { +// MMA unit test on Turing +TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - int M = 511, N = 257, K = 88; // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); + auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); // [N,K] - auto tv1 = makeContigTensor(2, DataType::Half); + auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); fusion.addInput(tv0); fusion.addInput(tv1); @@ -2411,8 +1541,8 @@ TEST_F(NVFuserTest, FusionTuringMatmulTNRegDoubleBuffer_CUDA) { fusion.addOutput(tv2); MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.cta_tile = GemmTile(16, 8, 16); + gemm_tile.warp_tile = GemmTile(16, 8, 16); gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_builder = @@ -2421,13 +1551,31 @@ TEST_F(NVFuserTest, FusionTuringMatmulTNRegDoubleBuffer_CUDA) { mma_builder.configureMma(tv2); - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); + auto tv0cw = tv0b->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1b->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + + auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + + // [M,N,K] -> [N,M,K] + tv0cr->reorder({{-2, -3}, {-3, -2}}); + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + 
tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); + auto t0 = at::randn({16, 16}, options); + auto t1 = at::randn({8, 16}, options); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -2437,19 +1585,18 @@ TEST_F(NVFuserTest, FusionTuringMatmulTNRegDoubleBuffer_CUDA) { auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } -// Matmul test on Turing -TEST_F(NVFuserTest, FusionTuringMatmulTTRegDoubleBuffer_CUDA) { +// MMA unit test on Turing +TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - int M = 512, N = 256, K = 128; // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); + auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); + auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); fusion.addInput(tv0); fusion.addInput(tv1); @@ -2457,15 +1604,13 @@ TEST_F(NVFuserTest, FusionTuringMatmulTTRegDoubleBuffer_CUDA) { auto tv0b = broadcast(tv0, {false, false, true}); auto tv1b = broadcast(tv1, {true, false, false}); - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); fusion.addOutput(tv2); MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.cta_tile = GemmTile(16, 8, 16); + gemm_tile.warp_tile = GemmTile(16, 8, 16); gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_builder = @@ -2474,13 +1619,37 @@ TEST_F(NVFuserTest, FusionTuringMatmulTTRegDoubleBuffer_CUDA) { mma_builder.configureMma(tv2); - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); + auto tv0cw = tv0b->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1b->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + + auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + + // [M,K,N] -> [N,M,K] + tv0cr->reorder({{-3, -2}, {-2, -1}, {-1, -3}}); + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [M,K,N] -> [M,N,K] + tv1cr->reorder({{-2, -1}, {-1, -2}}); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // [M,K,N] -> [M,N,K] + tv2c->reorder({{-2, -1}, {-1, -2}}); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({K, N}, options); + auto t0 = at::randn({16, 16}, options); + auto t1 = at::randn({16, 8}, options); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -2490,46 +1659,74 @@ TEST_F(NVFuserTest, FusionTuringMatmulTTRegDoubleBuffer_CUDA) { auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); - 
TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } -// Matmul test on Turing -TEST_F(NVFuserTest, FusionTuringMatmulNTRegDoubleBuffer_CUDA) { +// MMA unit test on Turing +TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - int M = 512, N = 256, K = 128; // [K,M] - auto tv0 = makeContigTensor(2, DataType::Half); + auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); + auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); fusion.addInput(tv0); fusion.addInput(tv1); // [K,M,N] auto tv0b = broadcast(tv0, {false, false, true}); auto tv1b = broadcast(tv1, {false, true, false}); - auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); fusion.addOutput(tv2); MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.cta_tile = GemmTile(16, 8, 16); + gemm_tile.warp_tile = GemmTile(16, 8, 16); gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_builder = MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::NT); - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); + mma_builder.configureMma(tv2); + + auto tv0cw = tv0b->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1b->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + + auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + + // [K,M,N] -> [N,M,K] + tv0cr->reorder({{-3, -1}, {-1, -3}}); + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [K,M,N] -> [M,N,K] + tv1cr->reorder({ + {-3, -1}, + {-2, -3}, + {-1, -2}, + }); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // [K,M,N] -> [M,N,K] + tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}}); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({K, M}, options); - auto t1 = at::randn({K, N}, options); + auto t0 = at::randn({16, 16}, options); + auto t1 = at::randn({16, 8}, options); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -2539,7 +1736,51 @@ TEST_F(NVFuserTest, FusionTuringMatmulNTRegDoubleBuffer_CUDA) { auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +// Matmul test for Turing MMA: across supported layouts +TEST_F(NVFuserTest, FusionTuringMatmul_CUDA) { + // Keep multiples of 8 to keep vectorizable. 
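+  // (Why multiples of 8: the fp16 operands are loaded with 8-wide vectorized
+  // accesses, i.e. 8 halfs = 16 bytes = one 128-bit transaction per thread,
+  // so any dimension that can end up innermost and contiguous has to divide
+  // by 8. The sizes below are 8 * 63, 8 * 17 and 8 * 31. A compile-time
+  // restatement of that arithmetic, added here for illustration:)
+  constexpr int kVecBytes = 16; // one 128-bit vectorized load
+  constexpr int kHalfBytes = 2; // sizeof fp16
+  static_assert(
+      kVecBytes / kHalfBytes == 8, "8 half elements fill one 128-bit load");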
+ int M = 504, N = 136, K = 248; + + for (auto layout : kAllSupportedLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1, layout); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(layout); + + MatmulParam params(mma_builder); + params.tile_sizes = gemm_tile; + scheduleMatmul(tv2, tv0, tv1, params); + + at::manual_seed(0); + auto inputs = fp16MatmulAtInput(M, N, K, layout); + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second})); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + } } // Matmul test on ampere, using ampere memory ops @@ -3527,319 +2768,93 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) { } // Matmul test on Ampere using ldmatrix.x4 to load operands -TEST_F(NVFuserTest, FusionAmpereMatmulTNLarge_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - int M = 511, N = 257, K = 88; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [N,K] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. 
- auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 16, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TN); - - // Call scheduler - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); - - FusionExecutor fe; - - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1})); - - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -// Matmul test on Ampere using ldmatrix.x4 to load operands -TEST_F(NVFuserTest, FusionAmpereMatmulNTLarge_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - int M = 512, N = 256, K = 128; - - // [K,M] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [K,M,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {false, true, false}); - - auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 16, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::NT); - - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({K, M}, options); - auto t1 = at::randn({K, N}, options); - - FusionExecutor fe; - - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {t0, t1})); - - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); - - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -// Matmul test on Ampere using ldmatrix.x4 to load operands -TEST_F(NVFuserTest, FusionAmpereMatmulTTLarge_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - int M = 512, N = 256, K = 128; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [K,N] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,K,N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. 
-  auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});
-
-  fusion.addOutput(tv2);
-
-  MatMulTileOptions gemm_tile;
-  gemm_tile.cta_tile = GemmTile(128, 128, 32);
-  gemm_tile.warp_tile = GemmTile(64, 64, 32);
-  gemm_tile.instruction_tile = GemmTile(16, 16, 16);
-
-  auto mma_builder =
-      MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
-          .layout(MmaOptions::MmaInputLayout::TT);
-
-  MatmulParam params(mma_builder);
-  params.tile_sizes = gemm_tile;
-  scheduleMatmul(tv2, tv0, tv1, params);
-
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-  auto t0 = at::randn({M, K}, options);
-  auto t1 = at::randn({K, N}, options);
-
-  FusionExecutor fe;
-
-  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
-      8, 0, fe.compileFusion(&fusion, {t0, t1}));
-
-  auto cg_outputs = fe.runFusion({t0, t1});
-
-  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat));
-
-  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
-}
-
-// Matmul test on Turing using ldmatrix.x4 to load operands
-TEST_F(NVFuserTest, FusionTuringMatmulTNLarge_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-  int M = 511, N = 257, K = 88;
-
-  // [M,K]
-  auto tv0 = makeContigTensor(2, DataType::Half);
-  // [N,K]
-  auto tv1 = makeContigTensor(2, DataType::Half);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  // [M,N,K]
-  auto tv0b = broadcast(tv0, {false, true, false});
-  auto tv1b = broadcast(tv1, {true, false, false});
-
-  // Leaving both sets of mma inputs for volta outside
-  // currently since they need to be swizzled.
-  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});
-
-  fusion.addOutput(tv2);
-
-  MatMulTileOptions gemm_tile;
-  gemm_tile.cta_tile = GemmTile(128, 128, 32);
-  gemm_tile.warp_tile = GemmTile(64, 64, 32);
-  gemm_tile.instruction_tile = GemmTile(16, 16, 16);
-
-  auto mma_builder =
-      MmaBuilder(MmaOptions::MacroType::Turing_16_16_16, gemm_tile)
-          .layout(MmaOptions::MmaInputLayout::TN);
-
-  mma_builder.configureMma(tv2);
-
-  MatmulParam params(mma_builder);
-  params.tile_sizes = gemm_tile;
-  scheduleMatmul(tv2, tv0, tv1, params);
-
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-  auto t0 = at::randn({M, K}, options);
-  auto t1 = at::randn({N, K}, options);
-
-  FusionExecutor fe;
-  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
-      7, 5, fe.compileFusion(&fusion, {t0, t1}));
-
-  auto cg_outputs = fe.runFusion({t0, t1});
-
-  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));
-
-  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
-}
-
-// Matmul test on Turing using ldmatrix.x4 to load operands
-TEST_F(NVFuserTest, FusionTuringMatmulTTLarge_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-  int M = 512, N = 256, K = 128;
-
-  // [M,K]
-  auto tv0 = makeContigTensor(2, DataType::Half);
-  // [K,N]
-  auto tv1 = makeContigTensor(2, DataType::Half);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  // [M,K,N]
-  auto tv0b = broadcast(tv0, {false, false, true});
-  auto tv1b = broadcast(tv1, {true, false, false});
-
-  // Leaving both sets of mma inputs for volta outside
-  // currently since they need to be swizzled.
-  auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});
-
-  fusion.addOutput(tv2);
-
-  MatMulTileOptions gemm_tile;
-  gemm_tile.cta_tile = GemmTile(128, 128, 32);
-  gemm_tile.warp_tile = GemmTile(64, 64, 32);
-  gemm_tile.instruction_tile = GemmTile(16, 16, 16);
-
-  auto mma_builder =
-      MmaBuilder(MmaOptions::MacroType::Turing_16_16_16, gemm_tile)
-          .layout(MmaOptions::MmaInputLayout::TT);
-
-  mma_builder.configureMma(tv2);
-
-  MatmulParam params(mma_builder);
-  params.tile_sizes = gemm_tile;
-  scheduleMatmul(tv2, tv0, tv1, params);
-
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-  auto t0 = at::randn({M, K}, options);
-  auto t1 = at::randn({K, N}, options);
-
-  FusionExecutor fe;
-  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
-      7, 5, fe.compileFusion(&fusion, {t0, t1}));
-
-  auto cg_outputs = fe.runFusion({t0, t1});
-
-  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat));
-
-  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) {
+  // Keep multiples of 8 to keep vectorizable.
+  int M = 504, N = 136, K = 248;
+  for (auto layout : kAllSupportedLayout) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
+    auto tv0 = makeContigTensor(2, DataType::Half);
+    auto tv1 = makeContigTensor(2, DataType::Half);
+
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
+
+    auto tv2 = matmul(tv0, tv1, layout);
+
+    fusion.addOutput(tv2);
+
+    MatMulTileOptions gemm_tile;
+    gemm_tile.cta_tile = GemmTile(128, 128, 32);
+    gemm_tile.warp_tile = GemmTile(64, 64, 32);
+    gemm_tile.instruction_tile = GemmTile(16, 16, 16);
+
+    auto mma_builder =
+        MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
+            .layout(layout);
+
+    MatmulParam params(mma_builder);
+    params.tile_sizes = gemm_tile;
+    params.async_gmem_load_operands = true;
+    params.double_buffer_options.double_buffer_smem_write = true;
+    params.double_buffer_options.smem_double_buffer_stage = 4;
+    scheduleMatmul(tv2, tv0, tv1, params);
+
+    at::manual_seed(0);
+    auto inputs = fp16MatmulAtInput(M, N, K, layout);
+
+    FusionExecutor fe;
+    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+        8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
+    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
+    auto tref = atMatmul(
+        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
+    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+  }
 }
 
-// Matmul test on Turing using ldmatrix.x4 to load operands
-TEST_F(NVFuserTest, FusionTuringMatmulNTLarge_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-  int M = 512, N = 256, K = 128;
-
-  // [K,M]
-  auto tv0 = makeContigTensor(2, DataType::Half);
-  // [K,N]
-  auto tv1 = makeContigTensor(2, DataType::Half);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  // [K,M,N]
-  auto tv0b = broadcast(tv0, {false, false, true});
-  auto tv1b = broadcast(tv1, {false, true, false});
-
-  auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});
+// Matmul test on Turing MMA across all supported layouts
+TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) {
+  // Keep multiples of 8 to keep vectorizable.
+  int M = 504, N = 136, K = 248;
 
-  fusion.addOutput(tv2);
+  for (auto layout : kAllSupportedLayout) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
+    auto tv0 = makeContigTensor(2, DataType::Half);
+    auto tv1 = makeContigTensor(2, DataType::Half);
 
-  MatMulTileOptions gemm_tile;
-  gemm_tile.cta_tile = GemmTile(128, 128, 32);
-  gemm_tile.warp_tile = GemmTile(64, 64, 32);
-  gemm_tile.instruction_tile = GemmTile(16, 16, 16);
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
 
-  auto mma_builder =
-      MmaBuilder(MmaOptions::MacroType::Turing_16_16_16, gemm_tile)
-          .layout(MmaOptions::MmaInputLayout::NT);
+    auto tv2 = matmul(tv0, tv1, layout);
 
-  MatmulParam params(mma_builder);
-  params.tile_sizes = gemm_tile;
-  scheduleMatmul(tv2, tv0, tv1, params);
+    fusion.addOutput(tv2);
 
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-  auto t0 = at::randn({K, M}, options);
-  auto t1 = at::randn({K, N}, options);
+    MatMulTileOptions gemm_tile;
+    gemm_tile.cta_tile = GemmTile(128, 128, 32);
+    gemm_tile.warp_tile = GemmTile(64, 64, 32);
+    gemm_tile.instruction_tile = GemmTile(16, 16, 16);
 
-  FusionExecutor fe;
-  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
-      7, 5, fe.compileFusion(&fusion, {t0, t1}));
+    auto mma_builder =
+        MmaBuilder(MmaOptions::MacroType::Turing_16_16_16, gemm_tile)
+            .layout(layout);
 
-  auto cg_outputs = fe.runFusion({t0, t1});
+    MatmulParam params(mma_builder);
+    params.tile_sizes = gemm_tile;
+    scheduleMatmul(tv2, tv0, tv1, params);
 
-  auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat));
+    at::manual_seed(0);
+    auto inputs = fp16MatmulAtInput(M, N, K, layout);
 
-  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+    FusionExecutor fe;
+    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+        7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
+    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
+    auto tref = atMatmul(
+        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
+    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+  }
 }
 
 #undef NVFUSER_TEST_CUDA_ARCH_GUARD
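The new FusionAmpereMatmulLargeLoad_CUDA variant above also turns on asynchronous global-to-shared operand loads and a 4-stage shared-memory double buffer (async_gmem_load_operands, double_buffer_options). A rough sketch of the parameter struct those assignments imply; only the field names that appear in the tests are taken from this diff, while the struct layout and defaults are assumptions:

// Assumed shape of the scheduler parameters exercised above; field names
// come from the tests, defaults and exact layout are hypothetical.
struct DoubleBufferOptions {
  // Circular-buffer the gmem->smem operand writes.
  bool double_buffer_smem_write = false;
  // Number of in-flight shared-memory buffer stages.
  int smem_double_buffer_stage = 2;
};

struct MatmulParam {
  explicit MatmulParam(MmaBuilder builder) : mma_builder(std::move(builder)) {}
  MmaBuilder mma_builder;
  MatMulTileOptions tile_sizes;
  // Use asynchronous copies (cp.async, Ampere and newer) for operand loads.
  bool async_gmem_load_operands = false;
  DoubleBufferOptions double_buffer_options;
};

With a 4-stage buffer the main loop can keep several tile loads in flight while one tile is consumed, which is presumably why only the Ampere test (where cp.async is available) enables these options, while the Turing test leaves them at their defaults.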
diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp
index 6327ad2f78033..9a7d8607e0183 100644
--- a/torch/csrc/jit/codegen/cuda/type.cpp
+++ b/torch/csrc/jit/codegen/cuda/type.cpp
@@ -258,6 +258,8 @@ static const char* val_type2string(ValType t) {
       return "Predicate";
     case ValType::TensorIndex:
       return "TensorIndex";
+    case ValType::IntPair:
+      return "IntPair";
     default:
       TORCH_INTERNAL_ASSERT(false, "No string found for val type.");
   }
@@ -1023,6 +1025,26 @@ TORCH_CUDA_CU_API std::ostream& operator<<(
   return os;
 }
 
+TORCH_CUDA_CU_API std::ostream& operator<<(
+    std::ostream& os,
+    const SwizzleMode& swizzle) {
+  switch (swizzle) {
+    case SwizzleMode::NoSwizzle:
+      os << "NoSwizzle";
+      break;
+    case SwizzleMode::Loop:
+      os << "Loop";
+      break;
+    case SwizzleMode::Data:
+      os << "Data";
+      break;
+    default:
+      TORCH_INTERNAL_ASSERT(false, "undefined swizzle mode");
+      break;
+  }
+  return os;
+}
+
 TORCH_CUDA_CU_API c10::optional<std::string> inline_op_str(
     const UnaryOpType uotype) {
   const char* str = unary_op_type_inline_op2string(uotype);
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index 5442f85296aa9..bd2f9f2ded655 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -329,6 +329,9 @@ enum class DoubleBufferLoopStage { NotApplicable, Prolog, Main, Epilog };
 //! doesn't have the same type.
 enum class Swizzle2DType { NoSwizzle = 0, ZShape, Transpose, XOR, Scatter };
 
+//! Modes of swizzle, see [Note on swizzle mode].
+enum class SwizzleMode { NoSwizzle = 0, Data, Loop };
+
 // Returns if function needs an f suffix on the operator when operating on a
 // float value i.e. sin->sinf
 bool needFloatSuffix(UnaryOpType t);
@@ -360,6 +363,7 @@ TORCH_CUDA_CU_API std::ostream& operator<<(
     std::ostream&,
     const DoubleBufferLoopStage);
 TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const Swizzle2DType&);
+TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const SwizzleMode&);
 
 std::string stringifyBooleanOp(const UnaryOpType);
 std::string stringifyBooleanOp(const BinaryOpType);
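Taken together with the index_compute changes, SwizzleMode splits swizzles into two roles: Loop swizzles only reorder the traversal of loop domains and are resolved early, in the concrete indexing pass, while Data swizzles change where elements physically land (for example, shared-memory bank-conflict avoidance) and remain the job of the IndexSwizzle pass. A hedged usage sketch follows; the swizzle() overload taking a mode is inferred from the direction of this patch, not quoted from the headers:

// Sketch, assuming a TensorView::swizzle(type, x, y, mode) scheduling API.
// Data mode: swizzle the two innermost axes of a shared-memory tile so the
// physical smem layout avoids bank conflicts.
tv_smem->swizzle(Swizzle2DType::XOR, -2, -1, SwizzleMode::Data);
// Loop mode: only permute the iteration order over the same memory layout.
tv_out->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop);

// The printer added in type.cpp makes the mode visible in IR dumps:
std::cout << SwizzleMode::Loop; // prints "Loop"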