From 1b5dde927bf03a4b2a932ebd8cdf2724294706d0 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Sep 2025 00:14:50 +0800 Subject: [PATCH 01/10] Refactor GEMM and GEMM-SP operations to enhance clarity and maintainability - Removed deprecated prime factorization functions from `gemm.cc` and `gemm_sp.cc`. - Introduced a new `GemmWarpPolicy` class to manage warp policy attributes and methods, improving encapsulation. - Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities. - Enhanced `GetArchInt` function in `utils.cc` for better readability and type safety. - Added new `gemm_v2` function in `gemm.py` for improved GEMM operation with additional parameters and checks. --- src/op/gemm.cc | 40 ++-- src/op/gemm_py.cc | 259 +++++++++++++++++++++ src/op/gemm_py.h | 127 ++++++++++ src/op/gemm_sp.cc | 24 -- src/target/utils.cc | 41 +++- src/transform/lower_tile_op.cc | 31 ++- tilelang/__init__.py | 2 + tilelang/intrinsics/mma_macro_generator.py | 30 +-- tilelang/ir.py | 12 +- tilelang/language/__init__.py | 2 +- tilelang/language/gemm.py | 177 ++++++++++++++ tilelang/language/kernel.py | 8 + tilelang/layout/swizzle.py | 2 + tilelang/tileop/__init__.py | 1 + tilelang/tileop/gemm/__init__.py | 162 +++++++++++++ tilelang/utils/target.py | 53 +++++ 16 files changed, 887 insertions(+), 84 deletions(-) create mode 100644 src/op/gemm_py.cc create mode 100644 src/op/gemm_py.h create mode 100644 tilelang/tileop/__init__.py create mode 100644 tilelang/tileop/gemm/__init__.py diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 011dc8142..cae67c936 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -18,30 +18,6 @@ namespace tl { using namespace tir; -/** - * @brief Compute the prime factorization of an integer. - * - * Returns the prime factors of x in non-decreasing order by repeatedly dividing - * out the smallest possible factor. - * - * @param x Integer to factorize. If x <= 1, an empty vector is returned. - * @return std::vector Prime factors of x (with multiplicity), in - * non-decreasing order. - */ -static std::vector toPrimeFactors(int x) { - int i = 2; - std::vector result; - while (x > 1) { - if (x % i == 0) { - x /= i; - result.push_back(i); - } else { - i++; - } - } - return result; -} - /** * @brief Construct a Gemm operator from serialized TL arguments and a buffer * map. @@ -632,5 +608,21 @@ TIR_REGISTER_TL_OP(Gemm, gemm) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TVM_REGISTER_OP("tl.GemmWarpPolicy") + .set_attr("TScriptPrinterName", "GemmWarpPolicy"); + +TVM_FFI_STATIC_INIT_BLOCK({ + GemmNode::RegisterReflection(); + GemmWarpPolicyNode::RegisterReflection(); + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition", + [](GemmWarpPolicy policy, int M, int N, int block_size, + Target target, bool is_wgmma) { + policy->ComputeWarpPartition(M, N, block_size, target, + is_wgmma); + return; + }); +}); + } // namespace tl } // namespace tvm diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc new file mode 100644 index 000000000..b20c2473e --- /dev/null +++ b/src/op/gemm_py.cc @@ -0,0 +1,259 @@ +/*! + * \file tl/op/gemm_py.cc + * \brief Implementation of General Matrix Multiplication (GEMM) operators + */ + +#include "gemm_py.h" + +#include "builtin.h" +#include +#include +#include +#include + +#include "../target/utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/** + * @brief Construct a Gemm operator from serialized TL arguments and a buffer + * map. + * + * This constructor deserializes operator parameters from `args` and resolves + * buffer references via `vmap`, populating an internal GemmPyNode with: + * - device pointers for A, B, C and their corresponding Buffer objects, + * - transpose flags for A and B, + * - matrix dimensions M, N, K, + * - warp allocation policy and clear_accum flag, + * - strides and memory offsets for A and B, + * - optional kPack (must be 1 or 2) and optional wg_wait. + * + * The populated GemmPyNode is stored into the wrapper's internal `data_`. + * + * @param args Positional serialized arguments produced by the TL frontend: + * expected layout is: + * [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool), + * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), + * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), + * (optional) kPack (Int), (optional) wg_wait (Int)] + * @param vmap Mapping from access pointer vars to Buffer objects used to + * resolve the Buffer corresponding to each pointer argument. + * + * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor + * fails with an ICHECK (runtime assertion). No other validation is + * performed here. + */ +GemmPy::GemmPy(Array args, BufferMap vmap) { + ObjectPtr node = make_object(); + + node->Aptr = args[0]; + node->Bptr = args[1]; + node->Cptr = args[2]; + node->A = vmap[GetVarFromAccessPtr(node->Aptr)]; + node->B = vmap[GetVarFromAccessPtr(node->Bptr)]; + node->C = vmap[GetVarFromAccessPtr(node->Cptr)]; + node->trans_A = args[3].as().value(); + node->trans_B = args[4].as().value(); + node->M = args[5].as().value()->value; + node->N = args[6].as().value()->value; + node->K = args[7].as().value()->value; + node->policy = GemmWarpPolicy(args[8].as().value()->value); + node->clear_accum = args[9].as().value(); + node->stride_A = args[10].as().value()->value; + node->stride_B = args[11].as().value()->value; + node->offset_A = args[12].as().value()->value; + node->offset_B = args[13].as().value()->value; + if (args.size() > 14) { + node->kPack = args[14].as().value()->value; + if (node->kPack != 1 && node->kPack != 2) { + ICHECK(false) << "kPack must be 1 or 2"; + } + } + if (args.size() > 15) { + node->wg_wait = args[15].as().value()->value; + } + data_ = std::move(node); +} + +/** + * @brief Create a copy of this GemmPyNode as a TileOperator. + * + * Constructs a new GemmPyNode by copying the current node state and returns it + * wrapped in a Gemm TileOperator. + * + * @return TileOperator A Gemm operator that owns a copy of this node. + */ +TileOperator GemmPyNode::Clone() const { + auto op = make_object(*this); + return GemmPy(op); +} + +GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size, + Target target) const { + int warp_size = TargetGetWarpSize(target); + int num_warps = block_size / warp_size; + bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) && + (num_warps % 4 == 0) && CheckWGMMA(); + if (allow_wgmma) { + return GemmInst::kWGMMA; + } else if (TargetIsCDNA(target)) { + return GemmInst::kMFMA; + } else if (TargetIsCuda(target)) { + return GemmInst::kMMA; + } else { + ICHECK(0) << "Unsupported target for gemm: " << target->str(); + } +} + +/** + * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM. + * + * Evaluates device-memory placement, data-type combinations, transpose flags, + * and K divisibility constraints required for the Hopper WGMMA code path. + * + * The check returns true only when: + * - B resides in shared memory ("shared" or "shared.dyn"); and + * - (C, A, B) dtypes match one of the supported combinations below and K + * satisfies the required alignment; and + * - for combinations that require specific orientations, A is not transposed + * and B is transposed. + * + * Supported combinations and constraints: + * - C=float16: + * - A=float16, B=float16: K % 16 == 0 + * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % + * 32 == 0 + * - C=float32: + * - A=float16, B=float16: K % 16 == 0 + * - A=bfloat16, B=bfloat16: K % 16 == 0 + * - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0 + * - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0 + * - C=int32: + * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) + * and K % 32 == 0 + * + * @return true if WGMMA is supported for the current buffers, dtypes, and + * transpose/shape constraints; false otherwise. + */ +bool GemmPyNode::CheckWGMMA() const { + if (B.scope() != "shared.dyn" && B.scope() != "shared") { + return false; + } + + if (C->dtype == DataType::Float(16)) { + if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) + return K % 16 == 0; + else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3()) + return (!trans_A) && trans_B && K % 32 == 0; + else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2()) + return (!trans_A) && trans_B && K % 32 == 0; + else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3()) + return (!trans_A) && trans_B && K % 32 == 0; + else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2()) + return (!trans_A) && trans_B && K % 32 == 0; + else + return false; + } else if (C->dtype == DataType::Float(32)) { + if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) + return K % 16 == 0; + else if (A->dtype == DataType::BFloat(16) && + B->dtype == DataType::BFloat(16)) + return K % 16 == 0; + else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32)) + return (!trans_A) && trans_B && K % 8 == 0; + else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3()) + return (!trans_A) && trans_B && K % 32 == 0; + else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2()) + return (!trans_A) && trans_B && K % 32 == 0; + else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3()) + return (!trans_A) && trans_B && K % 32 == 0; + else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2()) + return (!trans_A) && trans_B && K % 32 == 0; + else + return false; + } else if (C->dtype == DataType::Int(32)) { + if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8)) + return (!trans_A) && trans_B && K % 32 == 0; + else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8)) + return (!trans_A) && trans_B && K % 32 == 0; + else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8)) + return (!trans_A) && trans_B && K % 32 == 0; + else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8)) + return (!trans_A) && trans_B && K % 32 == 0; + else + return false; + } else { + return false; + } +} + +/** + * @brief Parse and return the numeric GPU architecture from a Target's "arch" + * attribute. + * + * Examines the target's "arch" string and, if it matches the pattern + * "sm_", returns as an int. If the attribute is present but does not + * match that pattern, returns 0. + * + * Preconditions: the target must have an "arch" attribute (this is checked via + * ICHECK). + * + * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if + * the arch string does not match "sm_". + */ +static int GetArchInt(Target target) { + int arch_int = 0; + auto s = target->GetAttr("arch"); + ICHECK(s.defined()); + std::string arch = s.value(); + if (arch.rfind("sm_", 0) == 0) { + arch_int = std::stoi(arch.substr(3)); + } else { + arch_int = 0; + } + return arch_int; +} + +Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + auto block_size = *as_const_int(T.thread_bounds->extent); + GemmInst gemm_inst = GetGemmInst(block_size, T.target); + auto [warp_m, warp_n] = policy->ComputeWarpPartition( + M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA); + + if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { + auto stmt = Downcast( + (*f)(GetRef(this), T.target, T.thread_bounds, T.thread_var)); + return stmt; + } else { + LOG(FATAL) << "No lower function found for gemm_py"; + } +} + +LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + if (completed_) + return {}; + LayoutMap results; + + if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) { + results = Downcast( + (*f)(GetRef(this), T.target, T.thread_bounds)); + } else { + LOG(FATAL) << "No infer layout function found for gemm_py"; + } + + completed_ = true; + return results; +} + +TIR_REGISTER_TL_OP(GemmPy, gemm_py) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); }); +} // namespace tl +} // namespace tvm diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h new file mode 100644 index 000000000..d92cbe91c --- /dev/null +++ b/src/op/gemm_py.h @@ -0,0 +1,127 @@ +/*! + * \file tl/op/gemm_py.h + * \brief Define gemm operator. + * + */ + +#ifndef TVM_TL_OP_GEMM_PY_H_ +#define TVM_TL_OP_GEMM_PY_H_ + +#include "operator.h" +#include "gemm_py.h" + +namespace tvm { + +namespace tl { + +using namespace tir; + + +class GemmPyNode : public TileOperatorNode { +public: + bool CheckWGMMA() const; + tir::Buffer A, B, C; + // pointer to the A, B, C + PrimExpr Aptr, Bptr, Cptr; + bool trans_A, trans_B; + int M, N, K; + int stride_A, stride_B; + int offset_A, offset_B; + bool clear_accum = false; + // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack + // only will be enabled under cdna mfma instructions + int kPack = 1; + int wg_wait = 0; + mutable GemmWarpPolicy policy; + + static constexpr const char *_type_key = "tl.GemmPy"; + TVM_DECLARE_FINAL_OBJECT_INFO(GemmPyNode, TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("A", &GemmPyNode::A) + .def_ro("B", &GemmPyNode::B) + .def_ro("C", &GemmPyNode::C) + .def_ro("Aptr", &GemmPyNode::Aptr) + .def_ro("Bptr", &GemmPyNode::Bptr) + .def_ro("Cptr", &GemmPyNode::Cptr) + .def_ro("trans_A", &GemmPyNode::trans_A) + .def_ro("trans_B", &GemmPyNode::trans_B) + .def_ro("M", &GemmPyNode::M) + .def_ro("N", &GemmPyNode::N) + .def_ro("K", &GemmPyNode::K) + .def_ro("stride_A", &GemmPyNode::stride_A) + .def_ro("stride_B", &GemmPyNode::stride_B) + .def_ro("offset_A", &GemmPyNode::offset_A) + .def_ro("offset_B", &GemmPyNode::offset_B) + .def_ro("clear_accum", &GemmPyNode::clear_accum) + .def_ro("kPack", &GemmPyNode::kPack) + .def_ro("wg_wait", &GemmPyNode::wg_wait) + .def_ro("policy", &GemmPyNode::policy); + } + + bool SEqualReduce(const GemmPyNode *other, SEqualReducer equal) const { + return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) && + equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) && + equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) && + equal(trans_B, other->trans_B) && equal(M, other->M) && + equal(N, other->N) && equal(K, other->K) && + equal(stride_A, other->stride_A) && + equal(stride_B, other->stride_B) && + equal(offset_A, other->offset_B) && + equal(offset_B, other->offset_B) && + equal(clear_accum, other->clear_accum) && + equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) && + equal(policy, other->policy); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(A); + hash_reduce(B); + hash_reduce(C); + hash_reduce(Aptr); + hash_reduce(Bptr); + hash_reduce(Cptr); + hash_reduce(trans_A); + hash_reduce(trans_B); + hash_reduce(M); + hash_reduce(N); + hash_reduce(K); + hash_reduce(stride_A); + hash_reduce(stride_B); + hash_reduce(offset_A); + hash_reduce(offset_B); + hash_reduce(clear_accum); + hash_reduce(kPack); + hash_reduce(wg_wait); + hash_reduce(policy); + } + static constexpr bool _type_has_method_sequal_reduce = true; + static constexpr bool _type_has_method_shash_reduce = true; + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + + TileOperator Clone() const; + +private: + // Target GEMM instruction + enum class GemmInst : uint8_t { kMMA, kWGMMA, kUTCMMA, kMFMA }; + GemmInst GetGemmInst(int block_size, Target target) const; + + mutable bool completed_ = false; +}; + +class GemmPy : public TileOperator { +public: + TVM_DEFINE_OBJECT_REF_METHODS(GemmPy, TileOperator, GemmPyNode); + TVM_DLL GemmPy(Array args, BufferMap vmap); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_GEMM_PY_H_ \ No newline at end of file diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc index d4784e930..74e0f1950 100644 --- a/src/op/gemm_sp.cc +++ b/src/op/gemm_sp.cc @@ -17,30 +17,6 @@ namespace tvm { namespace tl { -/** - * @brief Decomposes a positive integer into its prime factors. - * - * Returns the prime factorization of `x` as a vector of prime factors in - * non-decreasing order. If `x <= 1` the returned vector is empty. - * - * @param x Integer to factorize (expected non-negative; behavior: returns empty - * for values <= 1). - * @return std::vector Prime factors of `x` (with repetition), e.g. 12 -> - * {2, 2, 3}. - */ -static std::vector toPrimeFactors(int x) { - int i = 2; - std::vector result; - while (x > 1) { - if (x % i == 0) { - x /= i; - result.push_back(i); - } else { - i++; - } - } - return result; -} /** * @brief Construct a GemmSP operator node from TL call arguments and a buffer diff --git a/src/target/utils.cc b/src/target/utils.cc index 35135c1dc..6ce2425ca 100644 --- a/src/target/utils.cc +++ b/src/target/utils.cc @@ -18,11 +18,11 @@ bool TargetIsRocm(Target target) { int GetArchInt(Target target) { auto s = target->GetAttr("arch"); ICHECK(s.defined()); - const char *arch_str = s.value().c_str(); - ICHECK_EQ(arch_str[0], 's'); - ICHECK_EQ(arch_str[1], 'm'); - ICHECK_EQ(arch_str[2], '_'); - return atoi(&arch_str[3]); + const std::string arch_str = s.value(); + ICHECK(arch_str.size() >= 3); + ICHECK_EQ(arch_str.compare(0, 3, "sm_"), 0) + << "arch string must start with sm_"; + return std::stoi(arch_str.substr(3)); } bool TargetIsVolta(Target target) { @@ -118,5 +118,36 @@ int TargetGetWarpSize(Target target) { return res; } +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tl.TargetIsCuda", + [](Target target) { return TargetIsCuda(target); }) + .def("tl.TargetIsRocm", + [](Target target) { return TargetIsRocm(target); }) + .def("tl.TargetIsVolta", + [](Target target) { return TargetIsVolta(target); }) + .def("tl.TargetIsTuring", + [](Target target) { return TargetIsTuring(target); }) + .def("tl.TargetIsAmpere", + [](Target target) { return TargetIsAmpere(target); }) + .def("tl.TargetIsHopper", + [](Target target) { return TargetIsHopper(target); }) + .def("tl.TargetIsSM120", + [](Target target) { return TargetIsSM120(target); }) + .def("tl.TargetIsCDNA", + [](Target target) { return TargetIsCDNA(target); }) + .def("tl.TargetHasAsyncCopy", + [](Target target) { return TargetHasAsyncCopy(target); }) + .def("tl.TargetHasLdmatrix", + [](Target target) { return TargetHasLdmatrix(target); }) + .def("tl.TargetHasStmatrix", + [](Target target) { return TargetHasStmatrix(target); }) + .def("tl.TargetHasBulkCopy", + [](Target target) { return TargetHasBulkCopy(target); }) + .def("tl.TargetGetWarpSize", + [](Target target) { return TargetGetWarpSize(target); }); +}); + } // namespace tl } // namespace tvm diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 708e2526c..d0a9c674a 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -303,26 +303,27 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { } else if (access_ptr_call->op.same_as(builtin::address_of())) { BufferLoad load = Downcast(access_ptr_call->args[0]); Array indices = load->indices; - Array shape = load->buffer->shape; + Array old_shape = load->buffer->shape; - CHECK_EQ(indices.size(), shape.size()) + CHECK_EQ(indices.size(), old_shape.size()) << "Indices size and shape size must match for general N-dimensional " "buffer " << "but got indices size: " << indices.size() - << " and shape size: " << shape.size(); + << " and shape size: " << old_shape.size(); PrimExpr elem_offset = 0; PrimExpr stride = 1; - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + for (int i = static_cast(old_shape.size()) - 1; i >= 0; --i) { elem_offset += indices[i] * stride; - stride *= shape[i]; + stride *= old_shape[i]; } PrimExpr smem_offset = elem_offset + (offset.defined() ? offset.value() : 0); auto new_buffer = buffer_remap_[load->buffer]; + auto new_shape = new_buffer->shape; auto buffer_map_iter = buffer_map_.find(Downcast(load->buffer->data)); @@ -337,26 +338,27 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { Array multi_dim_indices; PrimExpr remaining_offset = smem_offset; - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + for (int i = static_cast(old_shape.size()) - 1; i >= 0; --i) { multi_dim_indices.insert(multi_dim_indices.begin(), - floormod(remaining_offset, shape[i])); - remaining_offset = floordiv(remaining_offset, shape[i]); + floormod(remaining_offset, old_shape[i])); + remaining_offset = floordiv(remaining_offset, old_shape[i]); } auto forward_indices = layout_map_[load->buffer]->Forward(multi_dim_indices); PrimExpr new_offset = 0; PrimExpr stride_offset = 1; - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + for (int i = static_cast(new_shape.size()) - 1; i >= 0; --i) { new_offset += forward_indices[i] * stride_offset; - stride_offset *= shape[i]; + stride_offset *= new_shape[i]; } new_offset = analyzer_->Simplify(new_offset); Array new_indices; - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { - new_indices.insert(new_indices.begin(), floormod(new_offset, shape[i])); - new_offset = floordiv(new_offset, shape[i]); + for (int i = static_cast(new_shape.size()) - 1; i >= 0; --i) { + new_indices.insert(new_indices.begin(), + floormod(new_offset, new_shape[i])); + new_offset = floordiv(new_offset, new_shape[i]); } auto new_access_ptr = access_ptr_call.CopyOnWrite(); @@ -397,7 +399,6 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr; } BufferLoad load = Downcast(address_of_call->args[0]); - if (buffer_remap_.count(load->buffer)) { auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype); @@ -494,9 +495,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { * visitor processing. */ Stmt VisitStmt_(const EvaluateNode *op) final { - // LOG(INFO) << "evaluate node: " << op->value; const CallNode *call = op->value.as(); - // LOG(INFO) << "call: " << call->op; // Do not analysis the call node to the global function. if (call && call->op.as()) return Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 4fe8ddea6..96d611bd0 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -106,3 +106,5 @@ def _load_tile_lang_lib(): from .math import * # noqa: F403 from . import ir # noqa: F401 + +from . import tileop # noqa: F401 diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 8d4d43ebc..547f40f5b 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -2,7 +2,7 @@ from typing import Union, Tuple, Optional, Literal, Callable from tilelang.common import TransformKind from tvm import DataType -from tvm.tir import PrimExpr, IndexMap, Buffer +from tvm.tir import PrimExpr, IndexMap, Buffer, Var from tvm.runtime import convert from .utils import ( mma_store_index_map, @@ -50,6 +50,7 @@ def __init__( reduce_k: int = 1, num_elems_per_byte: int = 1, is_m_first: Optional[bool] = False, + thread_var: Optional[Var] = None, ): self.a_dtype = a_dtype self.b_dtype = b_dtype @@ -74,6 +75,7 @@ def __init__( self.reduce_k = reduce_k self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k self.num_elems_per_byte = num_elems_per_byte + self.thread_var = thread_var if self.warp_rows == 0 or self.warp_cols == 0: raise ValueError( @@ -112,6 +114,14 @@ def _initialize_is_m_first(self, is_m_first: Optional[bool] = False): if is_m_first is not None: self.is_m_first = is_m_first + def get_thread_binding(self): + if self.thread_var is None: + current_frame = T.KernelLaunchFrame.Current() + assert current_frame is not None, "Must be called in a T.Kernel Frame" + return current_frame.get_thread_binding() + else: + return self.thread_var + def get_store_index_map(self, inverse: bool = False) -> IndexMap: warp_size, local_size_c = self.WARP_SIZE, self.local_size_out index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32") @@ -166,8 +176,7 @@ def ldmatrix_a(self, a_dtype = self.a_dtype a_transposed = self.a_transposed - current_frame = T.KernelLaunchFrame.Current() - thread_binding = current_frame.get_thread_binding() + thread_binding = self.get_thread_binding() @T.macro def _warp_ldmatrix_a( @@ -209,8 +218,7 @@ def ldmatrix_b(self, local_size_b = self.local_size_b b_dtype = self.b_dtype b_transposed = self.b_transposed - current_frame = T.KernelLaunchFrame.Current() - thread_binding = current_frame.get_thread_binding() + thread_binding = self.get_thread_binding() @T.macro def _warp_ldmatrix_b( @@ -229,7 +237,6 @@ def _warp_ldmatrix_b( warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * micro_size_k, ) - B_shared_elem = B_shared_buf[ri, rj] T.ptx_ldmatrix( b_dtype, @@ -238,7 +245,7 @@ def _warp_ldmatrix_b( ".b16", B_local_buf.data, j * local_size_b, - T.address_of(B_shared_elem), + T.address_of(B_shared_buf[ri, rj]), get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), ) @@ -318,8 +325,7 @@ def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): C_buf_dims = len(C_buf.shape) assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" - current_frame = T.KernelLaunchFrame.Current() - thread_binding = current_frame.get_thread_binding() + thread_binding = self.get_thread_binding() # STS # MMA Store must be in simulated instead of TVM Intrins @@ -632,8 +638,7 @@ def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0): a_transposed = self.a_transposed transform_kind_a = self.transform_kind_a - current_frame = T.KernelLaunchFrame.Current() - thread_binding = current_frame.get_thread_binding() + thread_binding = self.get_thread_binding() @T.macro def _warp_ldmatrix_a( @@ -740,8 +745,7 @@ def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0): b_transposed = self.b_transposed num_elems_per_byte = self.num_elems_per_byte - current_frame = T.KernelLaunchFrame.Current() - thread_binding = current_frame.get_thread_binding() + thread_binding = self.get_thread_binding() @T.macro def _warp_ldmatrix_b( diff --git a/tilelang/ir.py b/tilelang/ir.py index d6bdc4aa0..d48aeeed8 100644 --- a/tilelang/ir.py +++ b/tilelang/ir.py @@ -2,6 +2,8 @@ from tvm.ir.base import Node from tvm.runtime import Scriptable import tvm.ffi +from tvm.target import Target +from tilelang import _ffi_api @tvm.ffi.register_object("tl.Fill") @@ -26,7 +28,15 @@ class Conv2DIm2ColOp(Node, Scriptable): @tvm.ffi.register_object("tl.GemmWarpPolicy") class GemmWarpPolicy(Node, Scriptable): - ... + policy_type: int + m_warp: int + n_warp: int + + def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, + is_wgmma: bool): + _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, + is_wgmma) + return self.m_warp, self.n_warp @tvm.ffi.register_object("tl.Gemm") diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index bd1a10881..6d22a14d6 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -43,7 +43,7 @@ alloc_barrier, # noqa: F401 ) from .copy import copy, c2d_im2col # noqa: F401 -from .gemm import GemmWarpPolicy, gemm # noqa: F401 +from .gemm import GemmWarpPolicy, gemm, gemm_v2 # noqa: F401 from .experimental.gemm_sp import gemm_sp # noqa: F401 from .fill import fill, clear # noqa: F401 from .reduce import ( diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py index aab540ed2..1cd5c8136 100644 --- a/tilelang/language/gemm.py +++ b/tilelang/language/gemm.py @@ -180,3 +180,180 @@ def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr k_pack, wg_wait, ) + + +# experimental currently, for fast compilation +def gemm_v2( + A: Union[tir.Buffer, tir.Var], + B: Union[tir.Buffer, tir.Var], + C: Union[tir.Buffer, tir.Var], + transpose_A: bool = False, + transpose_B: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, + k_pack: int = 1, + wg_wait: int = 0, +): + """Perform a General Matrix Multiplication (GEMM) operation. + + This function computes C = A @ B where A and B can optionally be transposed. + The operation supports various warp policies and accumulation modes. + + Args: + A (Union[tir.Buffer, tir.Var]): First input matrix + B (Union[tir.Buffer, tir.Var]): Second input matrix + C (Union[tir.Buffer, tir.Var]): Output matrix for results + transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False. + transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False. + policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square. + clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False. + k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1. + wg_wait (int, optional): Warp group wait count. Defaults to 0. + + Returns: + tir.Call: A handle to the GEMM operation + + Raises: + AssertionError: If the K dimensions of matrices A and B don't match + """ + + def legalize_arguments(arg: Union[tir.Buffer, tir.Var]): + """Convert let-bound variables to their corresponding buffers. + + Args: + arg (Union[tir.Buffer, tir.Var]): Input argument to legalize + + Returns: + Union[tir.Buffer, tir.Var]: The legalized argument + """ + if isinstance(arg, tir.Var) and T.has_let_value(arg): + return T.get_let_value(arg).buffer + return arg + + A = legalize_arguments(A) + B = legalize_arguments(B) + C = legalize_arguments(C) + + def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: + if isinstance(object, tir.Buffer): + return object.shape + elif isinstance(object, tir.BufferRegion): + region = object.region + shape = [] + for r in region: + shape.append(r.extent) + return shape + else: + raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}") + + def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: + if isinstance(object, tir.Buffer): + strides = [] + stride = 1 + for s in reversed(object.shape): + strides.insert(0, stride) + stride *= s + return strides + elif isinstance(object, tir.BufferRegion): + buffer, _ = object.buffer, object.region + strides = [] + stride = 1 + for s in reversed(buffer.shape): + strides.insert(0, stride) + stride *= s + return strides + else: + raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}") + + A_shape = retrieve_shape(A) + B_shape = retrieve_shape(B) + C_shape = retrieve_shape(C) + + A_stride = retrieve_stride(A) + B_stride = retrieve_stride(B) + + assert len(C_shape) == 2, "current only support C as a 2D tensor" + assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor" + assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" + if len(A_shape) > 2: + for i in range(len(A_shape) - 2): + assert A_shape[i] == 1, \ + "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + if len(B_shape) > 2: + for i in range(len(B_shape) - 2): + assert B_shape[i] == 1, \ + "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + + M, N = C_shape + K = A_shape[-2] if transpose_A else A_shape[-1] + K_B = B_shape[-1] if transpose_B else B_shape[-2] + assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}" + + stride_a = A_stride[-2] + stride_b = B_stride[-2] + + def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion], + access_type: str = "r") -> tir.PrimExpr: + if isinstance(object, tir.Buffer): + return object.access_ptr(access_type) + elif isinstance(object, tir.BufferRegion): + buffer, region = object.buffer, object.region + indices = [] + for r in region: + indices.append(r.min) + strides = [] + stride = 1 + for s in reversed(buffer.shape): + strides.insert(0, stride) + stride *= s + offset = 0 + # not offset the last two dimension + for i in range(len(indices) - 2): + offset += indices[i] * strides[i] + return buffer.access_ptr(access_mask=access_type, offset=offset) + else: + raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}") + + def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr: + """Retrieve the offset of the buffer or buffer region.""" + if isinstance(object, tir.Buffer): + return [0] * len(object.shape) + elif isinstance(object, tir.BufferRegion): + _, region = object.buffer, object.region + indices = [] + for r in region: + indices.append(r.min) + return indices + else: + raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}") + + A_offset = retrieve_offset(A) + B_offset = retrieve_offset(B) + assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0" + assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0" + offset_a = A_offset[-1] + offset_b = B_offset[-1] + + Aptr = retrieve_ptr(A, "r") + Bptr = retrieve_ptr(B, "r") + Cptr = retrieve_ptr(C, "rw") + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.gemm_py"), + Aptr, + Bptr, + Cptr, + transpose_A, + transpose_B, + M, + N, + K, + policy, + clear_accum, + stride_a, + stride_b, + offset_a, + offset_b, + k_pack, + wg_wait, + ) diff --git a/tilelang/language/kernel.py b/tilelang/language/kernel.py index 0ce6e6ece..3f61e70db 100644 --- a/tilelang/language/kernel.py +++ b/tilelang/language/kernel.py @@ -261,46 +261,54 @@ def Kernel( def get_thread_binding(dim: int = 0) -> Var: """Returns the thread binding for the given dimension. """ + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_binding(dim) def get_thread_bindings() -> List[Var]: """Returns all three thread bindings. """ + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_bindings() def get_block_binding(dim: int = 0) -> Var: """Returns the block binding for the given dimension. """ + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_binding(dim) def get_block_bindings() -> List[Var]: """Returns all three block bindings. """ + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_bindings() def get_thread_extent(dim: int = 0) -> int: """Returns the thread extent for the given dimension. """ + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_extent(dim) def get_thread_extents() -> List[int]: """Returns all three thread extents. """ + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_extents() def get_block_extent(dim: int = 0) -> int: """Returns the block extent for the given dimension. """ + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_extent(dim) def get_block_extents() -> List[int]: """Returns all three block extents. """ + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_extents() diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py index d1087bd23..9fd2582b3 100644 --- a/tilelang/layout/swizzle.py +++ b/tilelang/layout/swizzle.py @@ -5,6 +5,8 @@ from tilelang import _ffi_api +# Use a stable swizzled layout to ensure consistent memory access patterns. +# Swizzling should be enabled or disabled based on whether TMA (Tensor Memory Access) is applied. def make_swizzled_layout(buffer: tvm.tir.Buffer): assert len(buffer.shape) == 2 return _ffi_api.make_swizzled_layout( diff --git a/tilelang/tileop/__init__.py b/tilelang/tileop/__init__.py new file mode 100644 index 000000000..13e6c043d --- /dev/null +++ b/tilelang/tileop/__init__.py @@ -0,0 +1 @@ +from .gemm import GemmPy diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py new file mode 100644 index 000000000..81fe9353e --- /dev/null +++ b/tilelang/tileop/gemm/__init__.py @@ -0,0 +1,162 @@ +from tilelang import tvm as tvm +from tvm import tir +from tilelang.utils.target import ( + target_is_cuda, + target_is_hip, +) +from tilelang import _ffi_api +from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitter,) +from tilelang.layout import make_swizzled_layout +from tilelang import language as T +from tvm.target import Target +from tvm.ir.base import Node +from tvm.runtime import Scriptable +import tvm.ffi +from tilelang.ir import GemmWarpPolicy + + +@tvm.ffi.register_func("tl.gemm_py.infer_layout") +def gemm_py_infer_layout(gemm_py, target, thread_bounds): + thread_nums = thread_bounds.extent + return gemm_py.infer_layout(target, thread_nums) + +@tvm.ffi.register_func("tl.gemm_py.lower") +def gemm_py_lower(gemm_py, target, thread_bounds, thread_var): + thread_nums = thread_bounds.extent + stmt = gemm_py.lower(target, thread_nums, thread_var) + return stmt + +@tvm.ffi.register_object("tl.GemmPy") +class GemmPy(Node, Scriptable): + A: tir.Buffer + B: tir.Buffer + C: tir.Buffer + + APtr: tir.PrimExpr + BPtr: tir.PrimExpr + CPtr: tir.PrimExpr + + M: int + N: int + K: int + + trans_A: bool + trans_B: bool + + stride_A: int + stride_B: int + offset_A: int + offset_B: int + clear_accum: bool + k_pack: int + wg_wait: int + policy: GemmWarpPolicy + + + def infer_layout(self, target: Target, thread_nums: int): + if target_is_cuda(target): + # TODO(lei): Support more cuda architectures, now mma only + # Now only implement ssr layout + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) + warp_row_tiles = m_warp * 16 + warp_col_tiles = n_warp * 16 + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=self.M, + block_col_warps=self.N, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + ) + layout_map = { + self.A: make_swizzled_layout(self.A), + self.B: make_swizzled_layout(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + return layout_map + else: + raise ValueError(f"Unsupported target: {target}") + + + def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): + if target_is_cuda(target): + # TODO(lei): Support more cuda architectures, now mma only + # Now only implement ssr layout + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) + warp_row_tiles = m_warp * 16 + warp_col_tiles = n_warp * 16 + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=self.M, + block_col_warps=self.N, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + ) + + in_dtype = self.in_dtype + warp_rows = mma_emitter.warp_rows + warp_cols = mma_emitter.warp_cols + local_size_a = mma_emitter.local_size_a + local_size_b = mma_emitter.local_size_b + block_K = mma_emitter.chunk + micro_size_k = mma_emitter.micro_size_k + A_shared = self.A + B_shared = self.B + C_local = self.C + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + return _gemm_ssr.body + else: + raise ValueError(f"Unsupported target: {target}") + + + @property + def in_dtype(self) -> str: + assert self.A.dtype == self.B.dtype, "A and B must have the same dtype" + return self.A.dtype + + @property + def accum_dtype(self) -> str: + return self.C.dtype + + @property + def chunk(self) -> int: + return self.A.shape[-2] if self.trans_A else self.A.shape[-1] \ No newline at end of file diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py index 9e12115a2..ed696c29a 100644 --- a/tilelang/utils/target.py +++ b/tilelang/utils/target.py @@ -1,5 +1,6 @@ from typing import Literal, Union from tilelang import tvm as tvm +from tilelang import _ffi_api from tvm.target import Target from tvm.contrib import rocm from tilelang.contrib import nvcc @@ -81,3 +82,55 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto", if return_object: return Target(return_var) return return_var + + +def target_is_cuda(target: Target) -> bool: + return _ffi_api.TargetIsCuda(target) + + +def target_is_hip(target: Target) -> bool: + return _ffi_api.TargetIsRocm(target) + + +def target_is_volta(target: Target) -> bool: + return _ffi_api.TargetIsVolta(target) + + +def target_is_turing(target: Target) -> bool: + return _ffi_api.TargetIsTuring(target) + + +def target_is_ampere(target: Target) -> bool: + return _ffi_api.TargetIsAmpere(target) + + +def target_is_hopper(target: Target) -> bool: + return _ffi_api.TargetIsHopper(target) + + +def target_is_sm120(target: Target) -> bool: + return _ffi_api.TargetIsSM120(target) + + +def target_is_cdna(target: Target) -> bool: + return _ffi_api.TargetIsCDNA(target) + + +def target_has_async_copy(target: Target) -> bool: + return _ffi_api.TargetHasAsyncCopy(target) + + +def target_has_ldmatrix(target: Target) -> bool: + return _ffi_api.TargetHasLdmatrix(target) + + +def target_has_stmatrix(target: Target) -> bool: + return _ffi_api.TargetHasStmatrix(target) + + +def target_has_bulk_copy(target: Target) -> bool: + return _ffi_api.TargetHasBulkCopy(target) + + +def target_get_warp_size(target: Target) -> int: + return _ffi_api.TargetGetWarpSize(target) From 52800a5fda5fb6d7ecd05cbc23e2ebfa67f3a62e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Sep 2025 01:23:46 +0800 Subject: [PATCH 02/10] Refactor GEMM and frontend legalize operations for improved clarity and functionality - Updated `gemm_py.h` to include the correct header for GEMM operations. - Renamed `FrontendLegalizer` class to `LetInliner` and updated related methods to reflect this change, enhancing code clarity. - Modified the pass function from `FrontendLegalize` to `LetInline` for better alignment with its purpose. - Updated test cases to utilize the new `gemm_v2` function and adjusted the testing framework for improved output and clarity. - Removed obsolete test file `test_tilelang_transform_frontend_legalize.py` to streamline the test suite. - Enhanced the `LowerAndLegalize` function to utilize the new `LetInline` pass, improving the overall transformation process. --- src/op/gemm_py.h | 3 +- src/transform/frontend_legalize.cc | 12 ++--- .../test_tilelang_tilelibrary_gemm.py | 12 +++-- ... => test_tilelang_transform_let_inline.py} | 2 +- tilelang/engine/phase.py | 4 +- tilelang/intrinsics/mma_macro_generator.py | 17 ++++--- tilelang/tileop/__init__.py | 2 +- tilelang/tileop/gemm/__init__.py | 45 ++++++++++--------- tilelang/transform/__init__.py | 13 +----- tilelang/transform/simplify.py | 33 +++++++++++--- 10 files changed, 83 insertions(+), 60 deletions(-) rename testing/python/transform/{test_tilelang_transform_frontend_legalize.py => test_tilelang_transform_let_inline.py} (97%) diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h index d92cbe91c..fa3e22c1e 100644 --- a/src/op/gemm_py.h +++ b/src/op/gemm_py.h @@ -7,8 +7,8 @@ #ifndef TVM_TL_OP_GEMM_PY_H_ #define TVM_TL_OP_GEMM_PY_H_ +#include "gemm.h" #include "operator.h" -#include "gemm_py.h" namespace tvm { @@ -16,7 +16,6 @@ namespace tl { using namespace tir; - class GemmPyNode : public TileOperatorNode { public: bool CheckWGMMA() const; diff --git a/src/transform/frontend_legalize.cc b/src/transform/frontend_legalize.cc index 3326d8ea7..b366d02d1 100644 --- a/src/transform/frontend_legalize.cc +++ b/src/transform/frontend_legalize.cc @@ -34,11 +34,11 @@ namespace tl { using namespace tir; -class FrontendLegalizer : public arith::IRMutatorWithAnalyzer { +class LetInliner : public arith::IRMutatorWithAnalyzer { public: static PrimFunc Substitute(PrimFunc f) { arith::Analyzer analyzer; - FrontendLegalizer substituter(&analyzer); + LetInliner substituter(&analyzer); PrimFuncNode *fptr = f.CopyOnWrite(); fptr->body = substituter.VisitStmt(f->body); return f; @@ -82,16 +82,16 @@ class FrontendLegalizer : public arith::IRMutatorWithAnalyzer { using namespace tir::transform; -Pass FrontendLegalize() { +Pass LetInline() { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { - return FrontendLegalizer::Substitute(std::move(f)); + return LetInliner::Substitute(std::move(f)); }; - return CreatePrimFuncPass(pass_func, 0, "tl.FrontendLegalize", {}); + return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {}); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tl.transform.FrontendLegalize", FrontendLegalize); + refl::GlobalDef().def("tl.transform.LetInline", LetInline); }); } // namespace tl diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index fdfab324f..bbadd785f 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -44,7 +44,7 @@ def main( T.copy(B[bx * block_N, k * block_K], B_shared) else: T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B) T.copy(C_local, C[by * block_M, bx * block_N]) return main @@ -88,6 +88,7 @@ def run_gemm( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) + print(kernel.get_kernel_source()) profiler = kernel.get_profiler() def ref_program(A, B): @@ -157,7 +158,7 @@ def main( T.copy(B[bx * block_N, k * block_K], B_shared) else: T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_frag, B_shared, C_local, trans_A, trans_B) + T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) T.copy(C_local, C[by * block_M, bx * block_N]) return main @@ -224,4 +225,9 @@ def test_gemm_rs(): if __name__ == "__main__": - tilelang.testing.main() + # tilelang.testing.main() + tilelang.disable_cache() + # run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0) + # print("gemm fp16 nt ss done") + run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0) + print("gemm fp16 nn ss done") \ No newline at end of file diff --git a/testing/python/transform/test_tilelang_transform_frontend_legalize.py b/testing/python/transform/test_tilelang_transform_let_inline.py similarity index 97% rename from testing/python/transform/test_tilelang_transform_frontend_legalize.py rename to testing/python/transform/test_tilelang_transform_let_inline.py index e57a97026..aa2638af1 100644 --- a/testing/python/transform/test_tilelang_transform_frontend_legalize.py +++ b/testing/python/transform/test_tilelang_transform_let_inline.py @@ -7,7 +7,7 @@ def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - mod = tl.transform.FrontendLegalize()(mod) + mod = tl.transform.LetInline()(mod) tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 646cb66c1..b8ac49a9a 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -85,8 +85,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: """ mod = tir.transform.BindTarget(target)(mod) - # Legalize the frontend IR to make it compatible with TVM - mod = tilelang.transform.FrontendLegalize()(mod) + # Inline let expressions and statements + mod = tilelang.transform.LetInline()(mod) # Inject assumes to speedup tvm prover mod = tilelang.transform.InjectAssumes()(mod) # Simplify the IR expressions diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 547f40f5b..d155efb46 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -188,10 +188,12 @@ def _warp_ldmatrix_a( ): stride = A_shared_buf.shape[-1] tx, _, warp_m = self.extract_thread_binding(thread_binding) + trans = self.a_transposed + for i in T.serial(warp_rows): T.ptx_ldmatrix( a_dtype, - T.bool(False), + T.bool(trans), 4, ".b16", A_local_buf.data, @@ -230,22 +232,25 @@ def _warp_ldmatrix_b( ): stride = B_shared_buf.shape[-1] tx, warp_n, _ = self.extract_thread_binding(thread_binding) + trans = not self.b_transposed for j in T.serial(warp_cols): # Assign B_shared_elem - ri, rj = ( + wi, wk = ( warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * micro_size_k, ) + B_shared_buf_elem = B_shared_buf[wi, wk] if self.b_transposed else B_shared_buf[wk, + wi] T.ptx_ldmatrix( b_dtype, - T.bool(False), # TODO(lei): should be optimized + T.bool(trans), 4, ".b16", B_local_buf.data, j * local_size_b, - T.address_of(B_shared_buf[ri, rj]), + T.address_of(B_shared_buf_elem), get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), ) @@ -289,7 +294,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): b_local_stride + j * local_size_b, C_local_buf.data, i * warp_cols * local_size_out + j * local_size_out, - T.bool(False), + T.bool(False), # saturate ) T.ptx_mma( @@ -306,7 +311,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): b_local_stride + j * local_size_b + lift(local_size_b) // 2, C_local_buf.data, i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, - T.bool(False), + T.bool(False), # saturate ) return _warp_mma(A_local_buf, B_local_buf, C_local_buf) diff --git a/tilelang/tileop/__init__.py b/tilelang/tileop/__init__.py index 13e6c043d..5656494fe 100644 --- a/tilelang/tileop/__init__.py +++ b/tilelang/tileop/__init__.py @@ -1 +1 @@ -from .gemm import GemmPy +from .gemm import GemmPy # noqa: F401 diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 81fe9353e..27e1eea20 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -1,10 +1,7 @@ from tilelang import tvm as tvm from tvm import tir from tilelang.utils.target import ( - target_is_cuda, - target_is_hip, -) -from tilelang import _ffi_api + target_is_cuda,) from tilelang.intrinsics.mma_macro_generator import ( TensorCoreIntrinEmitter,) from tilelang.layout import make_swizzled_layout @@ -14,6 +11,7 @@ from tvm.runtime import Scriptable import tvm.ffi from tilelang.ir import GemmWarpPolicy +from tilelang.transform.simplify import _Simplify @tvm.ffi.register_func("tl.gemm_py.infer_layout") @@ -21,18 +19,20 @@ def gemm_py_infer_layout(gemm_py, target, thread_bounds): thread_nums = thread_bounds.extent return gemm_py.infer_layout(target, thread_nums) + @tvm.ffi.register_func("tl.gemm_py.lower") def gemm_py_lower(gemm_py, target, thread_bounds, thread_var): thread_nums = thread_bounds.extent stmt = gemm_py.lower(target, thread_nums, thread_var) return stmt + @tvm.ffi.register_object("tl.GemmPy") class GemmPy(Node, Scriptable): A: tir.Buffer B: tir.Buffer C: tir.Buffer - + APtr: tir.PrimExpr BPtr: tir.PrimExpr CPtr: tir.PrimExpr @@ -52,23 +52,23 @@ class GemmPy(Node, Scriptable): k_pack: int wg_wait: int policy: GemmWarpPolicy - def infer_layout(self, target: Target, thread_nums: int): if target_is_cuda(target): # TODO(lei): Support more cuda architectures, now mma only # Now only implement ssr layout - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) - warp_row_tiles = m_warp * 16 - warp_col_tiles = n_warp * 16 + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( a_dtype=self.in_dtype, b_dtype=self.in_dtype, accum_dtype=self.accum_dtype, a_transposed=self.trans_A, b_transposed=self.trans_B, - block_row_warps=self.M, - block_col_warps=self.N, + block_row_warps=m_warp, + block_col_warps=n_warp, warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=self.chunk, @@ -81,23 +81,23 @@ def infer_layout(self, target: Target, thread_nums: int): return layout_map else: raise ValueError(f"Unsupported target: {target}") - def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): if target_is_cuda(target): # TODO(lei): Support more cuda architectures, now mma only # Now only implement ssr layout - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) - warp_row_tiles = m_warp * 16 - warp_col_tiles = n_warp * 16 + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( a_dtype=self.in_dtype, b_dtype=self.in_dtype, accum_dtype=self.accum_dtype, a_transposed=self.trans_A, b_transposed=self.trans_B, - block_row_warps=self.M, - block_col_warps=self.N, + block_row_warps=m_warp, + block_col_warps=n_warp, warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=self.chunk, @@ -125,7 +125,6 @@ def _gemm_ssr() -> None: A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - for ki in T.serial(0, (block_K // micro_size_k)): # Load A into fragment mma_emitter.ldmatrix_a( @@ -143,10 +142,12 @@ def _gemm_ssr() -> None: # Perform Matrix Multiplication mma_emitter.mma(A_local, B_local, C_local) - return _gemm_ssr.body + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True).body else: raise ValueError(f"Unsupported target: {target}") - @property def in_dtype(self) -> str: @@ -156,7 +157,7 @@ def in_dtype(self) -> str: @property def accum_dtype(self) -> str: return self.C.dtype - + @property def chunk(self) -> int: - return self.A.shape[-2] if self.trans_A else self.A.shape[-1] \ No newline at end of file + return self.A.shape[-2] if self.trans_A else self.A.shape[-1] diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index da8cf51d9..e438d0864 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -2,7 +2,7 @@ # pylint: disable=invalid-name, unsupported-binary-operation from . import _ffi_api -from .simplify import Simplify, simplify_prim_func # noqa: F401 +from .simplify import Simplify, simplify_prim_func, LetInline # noqa: F401 from .pass_config import PassConfigKey # noqa: F401 from tilelang import tvm as tvm # noqa: F401 from tvm.ir.transform import PassContext # noqa: F401 @@ -68,17 +68,6 @@ def InjectSoftwarePipeline(): return _ffi_api.InjectSoftwarePipeline() # type: ignore -def FrontendLegalize(): - """FrontendLegalize - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.FrontendLegalize() # type: ignore - - def InjectAssumes(): """Inject Assumes diff --git a/tilelang/transform/simplify.py b/tilelang/transform/simplify.py index fd1dac38f..6b8fedfc3 100644 --- a/tilelang/transform/simplify.py +++ b/tilelang/transform/simplify.py @@ -5,6 +5,17 @@ from . import _ffi_api +def LetInline(): + """LetInline + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LetInline() # type: ignore + + def Simplify(simplify_arguments: bool = False): """Simplify @@ -16,13 +27,24 @@ def Simplify(simplify_arguments: bool = False): return _ffi_api.Simplify(simplify_arguments) # type: ignore -def _Simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]: +def _Simplify(stmt: Union[PrimFunc, IRModule], + inline_let: bool = False) -> Union[PrimFunc, IRModule]: if isinstance(stmt, PrimFunc): - mod = Simplify(simplify_arguments=True)(IRModule.from_expr(stmt)) + if inline_let: + mod = LetInline()(IRModule.from_expr(stmt)) + mod = Simplify(simplify_arguments=True)(mod) + else: + mod = Simplify(simplify_arguments=True)(IRModule.from_expr(stmt)) assert len(mod.functions) == 1, "Simplify should return a single function" return list(mod.functions.values()).pop() elif isinstance(stmt, IRModule): - return Simplify(simplify_arguments=True)(stmt) + if inline_let: + mod = LetInline()(stmt) + mod = Simplify(simplify_arguments=True)(mod) + else: + mod = Simplify(simplify_arguments=True)(stmt) + assert len(mod.functions) == 1, "Simplify should return a single function" + return list(mod.functions.values()).pop() else: raise ValueError(f"Unsupported type: {type(stmt)}") @@ -37,6 +59,7 @@ def wrapper(*args, **kwargs): return wrapper -def apply_simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]: +def apply_simplify(stmt: Union[PrimFunc, IRModule], + inline_let: bool = False) -> Union[PrimFunc, IRModule]: """Apply Simplify pass to a PrimFunc or IRModule.""" - return _Simplify(stmt) + return _Simplify(stmt, inline_let) From e299b419c780b71f95729d7928e2c69d0f35c41c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Sep 2025 15:19:54 +0800 Subject: [PATCH 03/10] Enhance CUDA code generation and testing for GEMM operations - Added indentation printing in `codegen_cuda.cc` for improved assembly code formatting. - Updated `test_tilelang_tilelibrary_gemm.py` to include additional GEMM test cases and shared memory allocation with specified scope. - Introduced new `matmul_sr` and `run_gemm_sr` functions for GEMM operations with shared and fragment memory layouts. - Refactored layout inference in `mma_macro_generator.py` to improve clarity and correctness in shared memory handling. - Enhanced `gemm/__init__.py` to support new GEMM operation combinations and layout inference logic. These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework. --- src/target/codegen_cuda.cc | 17 +- .../test_tilelang_tilelibrary_gemm.py | 152 ++++++++++++++++-- tilelang/intrinsics/mma_macro_generator.py | 19 +-- tilelang/tileop/gemm/__init__.py | 134 ++++++++++----- 4 files changed, 258 insertions(+), 64 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 2a4bb9c17..7e7d1456a 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1259,7 +1259,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { std::string asm_code = PrintMMAAssembly( shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); - + this->PrintIndent(); this->stream << asm_code; } else if (op->op.same_as(builtin::ptx_mma_sp())) { // arg 0: shape: mXnXkX @@ -1295,6 +1295,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { std::string metadata_offset = this->PrintExpr(op->args[13]); std::string sparse_selector = this->PrintExpr(op->args[14]); bool saturate = Downcast(op->args[15])->value; + this->PrintIndent(); std::string asm_code = PrintMMAAssembly( shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset, @@ -1330,10 +1331,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { os << "}\n"; } else { std::string smem_elem_offset = this->PrintExpr(op->args[6]); - need_cast_smem_ptr_to_int_ = true; - this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, - local_elem_offset, smem_ptr, - smem_elem_offset); + // need_cast_smem_ptr_to_int_ = true; + // this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, + // local_elem_offset, smem_ptr, + // smem_elem_offset); + std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); + if (trans == 1) + func_name += "_trans"; + // this->stream << func_name << "(" << local_ptr "" << ", " << smem_ptr << ");\n"; + this->PrintIndent(); + this->stream << func_name << "(" << smem_ptr << " + " << smem_elem_offset<< ", " << local_ptr << " + " << local_elem_offset << ");\n"; } } else if (op->op.same_as(builtin::mma_store())) { int m = Downcast(op->args[0])->value; diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index bbadd785f..714f8cef8 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -1,3 +1,4 @@ +from asyncio import threads from tilelang import tvm as tvm import tilelang.testing @@ -31,8 +32,8 @@ def main( C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared") C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): @@ -108,8 +109,11 @@ def ref_program(A, B): def test_gemm(): # More test case can be found in kernel/test_tilelang_kernel_gemm.py # GEMM tests for float16 - run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, - 2) # f16f16f16_nn + run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0) + run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0) + run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0) + run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0) + def matmul_rs( @@ -142,23 +146,26 @@ def main( C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared") A_frag = T.alloc_fragment(A_frag_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) + T.annotate_layout({ + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + }) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): if trans_A: T.copy(A[k * block_K, by * block_M], A_shared) - T.copy(A_shared, A_frag) else: T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(A_shared, A_frag) if trans_B: T.copy(B[bx * block_N, k * block_K], B_shared) else: T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) + # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B) T.copy(C_local, C[by * block_M, bx * block_N]) return main @@ -202,6 +209,7 @@ def run_gemm_rs( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) + print(kernel.get_kernel_source()) profiler = kernel.get_profiler() def ref_program(A, B): @@ -224,10 +232,134 @@ def test_gemm_rs(): run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2) +def matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + B_frag_shape = B_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared") + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + T.annotate_layout({ + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + }) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(B_shared, B_frag) + T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_sr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_gemm_sr(): + # GEMM tests for float16 + run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2) + + if __name__ == "__main__": # tilelang.testing.main() tilelang.disable_cache() + tilelang.testing.set_random_seed(42) # run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0) # print("gemm fp16 nt ss done") - run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0) - print("gemm fp16 nn ss done") \ No newline at end of file + # run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0) + # print("gemm fp16 nn ss done") + # run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0) + # print("gemm fp16 tn ss done") + # run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0) + # print("gemm fp16 tt ss done") + # run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128) + # print("gemm fp16 nt rs done") + run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128) + # run_gemm(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index d155efb46..f5a9123e9 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -189,8 +189,12 @@ def _warp_ldmatrix_a( stride = A_shared_buf.shape[-1] tx, _, warp_m = self.extract_thread_binding(thread_binding) trans = self.a_transposed - + for i in T.serial(warp_rows): + # Assign A_shared_buf_elem + wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k + A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk] + T.ptx_ldmatrix( a_dtype, T.bool(trans), @@ -198,10 +202,7 @@ def _warp_ldmatrix_a( ".b16", A_local_buf.data, i * local_size_a, - T.address_of(A_shared_buf[ - warp_m * warp_row_tiles + i * micro_size_x, - rk * chunk + ki * micro_size_k, - ]), + T.address_of(A_shared_buf_elem), get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), ) @@ -232,7 +233,7 @@ def _warp_ldmatrix_b( ): stride = B_shared_buf.shape[-1] tx, warp_n, _ = self.extract_thread_binding(thread_binding) - trans = not self.b_transposed + trans = not b_transposed for j in T.serial(warp_cols): # Assign B_shared_elem @@ -240,8 +241,7 @@ def _warp_ldmatrix_b( warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * micro_size_k, ) - B_shared_buf_elem = B_shared_buf[wi, wk] if self.b_transposed else B_shared_buf[wk, - wi] + B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi] T.ptx_ldmatrix( b_dtype, @@ -470,9 +470,6 @@ def forward_index(i: int, j: int) -> int: block_fragment = warp_fragment.repeat([warp_s, chunk // micro_size_r], repeat_on_thread=False, lower_dim_first=False) - print(f"base_fragment: {base_fragment}") - print(f"warp_fragment: {warp_fragment}") - print(f"block_fragment: {block_fragment}") return block_fragment def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment: diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 27e1eea20..81d25cf1f 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -12,7 +12,7 @@ import tvm.ffi from tilelang.ir import GemmWarpPolicy from tilelang.transform.simplify import _Simplify - +from tilelang.utils.language import is_shared, is_fragment @tvm.ffi.register_func("tl.gemm_py.infer_layout") def gemm_py_infer_layout(gemm_py, target, thread_bounds): @@ -73,12 +73,25 @@ def infer_layout(self, target: Target, thread_nums: int): warp_col_tiles=warp_col_tiles, chunk=self.chunk, ) - layout_map = { - self.A: make_swizzled_layout(self.A), - self.B: make_swizzled_layout(self.B), - self.C: mma_emitter.make_mma_store_layout(self.C), - } - return layout_map + if self.is_gemm_ss(): + return { + self.A: make_swizzled_layout(self.A), + self.B: make_swizzled_layout(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_sr(): + raise NotImplementedError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + elif self.is_gemm_rs(): + return { + # make mma load layout or ldmatrix layout? + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.C: mma_emitter.make_mma_store_layout(self.C), + self.B: make_swizzled_layout(self.B), + } + elif self.is_gemm_rr(): + raise NotImplementedError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") else: raise ValueError(f"Unsupported target: {target}") @@ -115,37 +128,70 @@ def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): B_shared = self.B C_local = self.C - @T.prim_func - def _gemm_ssr() -> None: - """ - The inner macro that loads data from shared buffers A_shared and - B_shared into local fragments, then issues Tensor Core mma ops, - accumulating into C_local. - """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - ) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local) - - # Simplify to optimize the index computing - # Must inline let statements to simplify the analysis - return _Simplify(_gemm_ssr, inline_let=True).body + if self.is_gemm_ss(): + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True).body + elif self.is_gemm_sr(): + raise NotImplementedError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + elif self.is_gemm_rs(): + A_local = self.A + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True).body + elif self.is_gemm_rr(): + raise NotImplementedError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") else: raise ValueError(f"Unsupported target: {target}") @@ -161,3 +207,15 @@ def accum_dtype(self) -> str: @property def chunk(self) -> int: return self.A.shape[-2] if self.trans_A else self.A.shape[-1] + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) From 1ab46ef8abd7f4eb3126efeee233052996c90c62 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 9 Sep 2025 15:33:55 +0800 Subject: [PATCH 04/10] Refactor GEMM layout and testing for improved clarity and functionality - Updated `gemm_layouts.cc` to enhance the layout generation logic for transposed and non-transposed GEMM operations. - Renamed and modified functions in `test_tilelang_tilelibrary_gemm.py` to reflect changes in GEMM function signatures and improve test coverage. - Introduced new GEMM operation combinations in `gemm/__init__.py` to support additional layouts and configurations. - Enhanced layout inference in `mma_layout.py` and `mma_macro_generator.py` for better handling of shared memory layouts. These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework. --- 3rdparty/tvm | 2 +- src/layout/gemm_layouts.cc | 6 +- src/op/gemm_py.cc | 14 +- src/transform/inject_pipeline.cc | 6 + .../test_tilelang_tilelibrary_gemm.py | 184 ++++++++++++++++-- tilelang/engine/phase.py | 4 + tilelang/intrinsics/mma_layout.py | 20 +- tilelang/intrinsics/mma_macro_generator.py | 91 ++++++--- tilelang/tileop/gemm/__init__.py | 65 ++++++- 9 files changed, 326 insertions(+), 66 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 1fc7578cd..eddefbd65 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 1fc7578cd1ff934455b07597508b5a67d7cb5a73 +Subproject commit eddefbd65acb7b1ea51dd18068b4049754c4fa7a diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 567bc644b..8daafc0da 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -205,16 +205,14 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n, ICHECK(block_k % 16 == 0); if (transposed) { auto base_layout = makeGemmFragment8x8()->Repeat({1, 2}, false, false); - auto warp_layout = base_layout->Replicate(block_m / warp_m) - ->Repeat({block_n / warp_n, 1}, true, false); + auto warp_layout = base_layout->Repeat({block_n / warp_n, 1}, true, false)->Replicate(block_m / warp_m); auto block_layout = warp_layout->Repeat({warp_n / 8, block_k / 16}, false, false); return block_layout; } else { auto base_layout = makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false); - auto warp_layout = base_layout->Replicate(block_m / warp_m) - ->Repeat({1, block_n / warp_n}, true); + auto warp_layout = base_layout->Repeat({1, block_n / warp_n}, true)->Replicate(block_m / warp_m); auto block_layout = warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true); return block_layout; diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index b20c2473e..039287460 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -12,6 +12,7 @@ #include #include "../target/utils.h" +#include "tvm/ffi/string.h" namespace tvm { namespace tl { @@ -224,9 +225,18 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA); if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { - auto stmt = Downcast( + auto prim_func = Downcast( (*f)(GetRef(this), T.target, T.thread_bounds, T.thread_var)); - return stmt; + BlockRealize block_realize = Downcast(prim_func->body); + ICHECK(prim_func->attrs.defined()); + auto global_symbol = prim_func->attrs.GetAttr("global_symbol"); + ICHECK(global_symbol.defined()); + auto block = block_realize->block; + { + BlockNode* n = block.CopyOnWrite(); + n->name_hint = global_symbol.value(); + } + return BlockRealize(block_realize->iter_values, block_realize->predicate, block); } else { LOG(FATAL) << "No lower function found for gemm_py"; } diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index 6e3570750..2806b2533 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -675,6 +675,12 @@ class PipelineRewriter : public StmtExprMutator { } new_block = Downcast(Substitute( new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); + + Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + BlockNode* n = new_block.CopyOnWrite(); + n->reads = access[0]; + n->writes = access[1]; + if (pipeline_info_[block].async) { auto &local_state = async_states_local[stage]; local_state.producer_head = normalized_access_index; diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index 714f8cef8..b095e3414 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -46,12 +46,13 @@ def main( else: T.copy(B[k * block_K, bx * block_N], B_shared) T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B) + # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) T.copy(C_local, C[by * block_M, bx * block_N]) return main -def run_gemm( +def run_gemm_ss( M, N, K, @@ -106,13 +107,13 @@ def ref_program(A, B): profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) -def test_gemm(): +def test_gemm_ss(): # More test case can be found in kernel/test_tilelang_kernel_gemm.py # GEMM tests for float16 - run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0) - run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0) - run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0) - run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0) + run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0) + run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0) + run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0) + run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0) @@ -165,7 +166,6 @@ def main( T.copy(B[k * block_K, bx * block_N], B_shared) T.copy(A_shared, A_frag) T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) - # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B) T.copy(C_local, C[by * block_M, bx * block_N]) return main @@ -228,8 +228,8 @@ def ref_program(A, B): def test_gemm_rs(): # GEMM tests for float16 - run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) - run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2) + run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 0) + run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 0) def matmul_sr( @@ -280,7 +280,10 @@ def main( else: T.copy(B[k * block_K, bx * block_N], B_shared) T.copy(B_shared, B_frag) - T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B) + # for i, j in T.Parallel(block_N, block_K): + # B_frag[i, j] = B_shared[j, i] + # T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B) + T.gemm(A_shared, B_frag, C_local, trans_A, trans_B) T.copy(C_local, C[by * block_M, bx * block_N]) return main @@ -345,21 +348,160 @@ def test_gemm_sr(): # GEMM tests for float16 run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2) + run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2) + + +def matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + B_frag_shape = B_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared") + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + T.annotate_layout({ + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + }) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.copy(B_shared, B_frag) + T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_gemm_rr(): + # GEMM tests for float16 + run_gemm_rr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_gemm_rr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2) + run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2) + run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2) if __name__ == "__main__": # tilelang.testing.main() tilelang.disable_cache() - tilelang.testing.set_random_seed(42) - # run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0) + # test_gemm_ss() + run_gemm_sr(128, 128, 128, False, False, "float16", "float16", "float16", 128, 128, 32, 2) + # tilelang.testing.set_random_seed(42) + # run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1) # print("gemm fp16 nt ss done") - # run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0) - # print("gemm fp16 nn ss done") - # run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0) - # print("gemm fp16 tn ss done") - # run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0) - # print("gemm fp16 tt ss done") - # run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128) + # exit() + + # run_gemm_rs(128, 128, 32, False, True, "float16", "float16", "float16", 128, 128, 32, 0) # print("gemm fp16 nt rs done") - run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128) - # run_gemm(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128) + # run_gemm_rs(128, 128, 32, False, False, "float16", "float16", "float16", 128, 128, 32, 0) + # print("gemm fp16 nn rs done") + # run_gemm_rs(128, 128, 32, True, False, "float16", "float16", "float16", 128, 128, 32, 0) + # print("gemm fp16 tn rs done") + # run_gemm_rs(128, 128, 32, True, True, "float16", "float16", "float16", 128, 128, 32, 0) + # print("gemm fp16 tt rs done") + + # run_gemm_rs(16, 16, 16, True, False, "float16", "float16", "float16", 16, 16, 16, 0, 32) + + # run_gemm_rr(128, 128, 32, False, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0) + # print("gemm bf16 nn rr done") + # run_gemm_rr(128, 128, 32, False, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0) + # print("gemm bf16 nt rr done") + # run_gemm_rr(128, 128, 32, True, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0) + # print("gemm bf16 tn rr done") + # run_gemm_rr(128, 128, 32, True, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0) + # print("gemm bf16 tt rr done") + + diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index b8ac49a9a..ef100ee9c 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -140,7 +140,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.IfStmtBinding()(mod) mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) mod = tilelang.transform.PipelinePlanning()(mod) + print("after pipeline planning") + print(mod) mod = tilelang.transform.InjectSoftwarePipeline()(mod) + print("after inject software pipeline") + print(mod) mod = tilelang.transform.MergeIfStmt()(mod) if allow_fence_proxy(target=target): # in hopper device, wgmma is an async proxy diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index e24f4caaf..e66194bbd 100644 --- a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -47,18 +47,26 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): # sr represents spatial + reduction layout # the first axis is spatial while the second axis is reduction -def shared_16x16_to_mma_32x8_layout_sr(i, j): +# mma.sync matrix A layout, if wanna trans, please apply map_indices +def shared_16x16_to_mma_a_32x8_layout(i, j): thread_id = 4 * (i % 8) + (j % 8) // 2 return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2) +def shared_16x16_to_mma_a_32x8_layout_trans(i, j): + return shared_16x16_to_mma_a_32x8_layout(j, i) -def shared_16x16_to_mma_32x8_layout_rs(i, j): - thread_id = 4 * (j % 8) + (i % 8) // 2 - return thread_id, 4 * (i // 8) + (j // 8) * 2 + (i % 2) +# mma.sync matrix B layout, if wanna trans, please apply map_indices +def shared_16x16_to_mma_b_32x8_layout(i, j): + thread_id = 4 * (i % 8) + (j % 8) // 2 + return thread_id, 4 * (i // 8) + (j // 8) * 2 + (j % 2) +def shared_16x16_to_mma_b_32x8_layout_trans(i, j): + return shared_16x16_to_mma_b_32x8_layout(j, i) -shared_16x16_to_mma_32x8_layout = shared_16x16_to_mma_32x8_layout_sr -shared_16x16_to_mma_32x8_layout_trans = shared_16x16_to_mma_32x8_layout_rs +shared_16x16_to_mma_32x8_layout_sr_a = shared_16x16_to_mma_a_32x8_layout +shared_16x16_to_mma_32x8_layout_sr_b = shared_16x16_to_mma_b_32x8_layout +shared_16x16_to_mma_32x8_layout_rs_a = shared_16x16_to_mma_a_32x8_layout_trans +shared_16x16_to_mma_32x8_layout_rs_b = shared_16x16_to_mma_b_32x8_layout_trans def shared_16x32_to_mma_32x16_layout(i, j): diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index f5a9123e9..6787a930f 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -397,41 +397,54 @@ def make_mma_load_layout(self, """ from tilelang.utils import is_fragment from tilelang.intrinsics.mma_layout import ( - shared_16x16_to_mma_32x8_layout_sr, - shared_16x16_to_mma_32x8_layout_rs, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_b, shared_16x32_to_mma_32x16_layout, shared_32x16_to_mma_32x16_layout, ) assert matrix in ["A", "B"], "matrix should be either A or B" - dtype = self.a_dtype if matrix == "A" else self.b_dtype + matrix_is_a : bool = matrix == "A" + matrix_is_b : bool = matrix == "B" + dtype = self.a_dtype if matrix_is_a else self.b_dtype dtype_bits = DataType(dtype).bits - transposed = self.a_transposed - assert transposed is False, "transposed is not supported yet" + transposed = self.a_transposed if matrix_is_a else self.b_transposed + # s represents spatial axis # r represents reduction axis # sr represents the two dims are spatial + reduction # rs represents the two dims are reduction + spatial - transform_func_sr: Callable = None - transform_func_rs: Callable = None + # sr also can represent a non-transposed basic layout + # then rs also can represent a transposed basic layout + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None if dtype_bits == 16: - transform_func_sr = shared_16x16_to_mma_32x8_layout_sr - transform_func_rs = shared_16x16_to_mma_32x8_layout_rs + transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a + transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b elif dtype_bits == 8: - transform_func_sr = shared_16x32_to_mma_32x16_layout - transform_func_rs = shared_32x16_to_mma_32x16_layout + transform_func_sr_a = shared_16x32_to_mma_32x16_layout + transform_func_sr_b = shared_32x16_to_mma_32x16_layout else: raise ValueError(f"Unsupported dtype {dtype}") + is_sr_conditions = [False] - is_sr_conditions.append(matrix == "A" and not transposed) - is_sr_conditions.append(matrix == "B" and transposed) + is_sr_conditions.append(matrix_is_a and not transposed) + is_sr_conditions.append(matrix_is_b and transposed) is_sr_axis_order = any(is_sr_conditions) - transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs + # the layout of mma.sync is row.col. + # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix_is_a: + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + elif matrix_is_b: + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) + else: + raise ValueError(f"Unsupported matrix {matrix}") assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format( local_buf.scope()) - if matrix == "A": + if matrix_is_a: micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k else: micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y @@ -440,10 +453,7 @@ def make_mma_load_layout(self, self.block_row_warps, self.block_col_warps, ) - warp_rows, warp_cols = self.warp_rows, self.warp_cols - warp_s = warp_rows if matrix == "A" else warp_cols - chunk = self.chunk - transform_func = transform_func + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") def forward_thread(i: int, j: int) -> int: @@ -465,11 +475,44 @@ def forward_index(i: int, j: int) -> int: forward_thread_fn=forward_thread, forward_index_fn=forward_index, ) - warp_fragment = base_fragment.repeat([block_row_warps, 1], - repeat_on_thread=True).replicate(block_col_warps) - block_fragment = warp_fragment.repeat([warp_s, chunk // micro_size_r], - repeat_on_thread=False, - lower_dim_first=False) + + warp_rows, warp_cols = self.warp_rows, self.warp_cols + chunk = self.chunk + + warp_s = warp_rows if matrix_is_a else warp_cols + warp_r = chunk // micro_size_r + block_s = block_row_warps if matrix_is_a else block_col_warps + replicate = block_col_warps if matrix_is_a else block_row_warps + + if is_sr_axis_order: + warp_fragment = base_fragment.repeat([warp_s, warp_r], + repeat_on_thread=False, + lower_dim_first=False) + if matrix_is_a: + block_fragment = warp_fragment.repeat([block_s, 1], + repeat_on_thread=True, + lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], + repeat_on_thread=True, + lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + else: + warp_fragment = base_fragment.repeat([warp_r, warp_s], + repeat_on_thread=False, + lower_dim_first=True) + if matrix_is_a: + block_fragment = warp_fragment.repeat([1, block_s], + repeat_on_thread=True, + lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], + repeat_on_thread=True, + lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + return block_fragment def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment: diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 81d25cf1f..30ea56777 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -80,16 +80,23 @@ def infer_layout(self, target: Target, thread_nums: int): self.C: mma_emitter.make_mma_store_layout(self.C), } elif self.is_gemm_sr(): - raise NotImplementedError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + return { + self.A: make_swizzled_layout(self.A), + self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), + self.C: mma_emitter.make_mma_store_layout(self.C), + } elif self.is_gemm_rs(): return { - # make mma load layout or ldmatrix layout? self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), - self.C: mma_emitter.make_mma_store_layout(self.C), self.B: make_swizzled_layout(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), } elif self.is_gemm_rr(): - raise NotImplementedError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), + self.C: mma_emitter.make_mma_store_layout(self.C), + } else: raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") else: @@ -159,9 +166,35 @@ def _gemm_ssr() -> None: # Simplify to optimize the index computing # Must inline let statements to simplify the analysis - return _Simplify(_gemm_ssr, inline_let=True).body + return _Simplify(_gemm_ssr, inline_let=True) elif self.is_gemm_sr(): - raise NotImplementedError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + B_local = self.B + @T.prim_func + def _gemm_srr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + # alloc_buffers body + # insert into parrent block + return _Simplify(_gemm_srr, inline_let=True) elif self.is_gemm_rs(): A_local = self.A @T.prim_func @@ -187,9 +220,25 @@ def _gemm_rsr() -> None: # Simplify to optimize the index computing # Must inline let statements to simplify the analysis - return _Simplify(_gemm_rsr, inline_let=True).body + return _Simplify(_gemm_rsr, inline_let=True) elif self.is_gemm_rr(): - raise NotImplementedError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + A_local = self.A + B_local = self.B + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + + for ki in T.serial(0, (block_K // micro_size_k)): + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) else: raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") else: From e36740d6b8544ab6a428bbe989d84604bdf19b6b Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 10 Sep 2025 00:25:30 +0800 Subject: [PATCH 05/10] Refactor GEMM layout and Python integration for improved functionality - Updated `gemm_layouts.cc` to correct the order of layout replication and repetition for transposed and non-transposed GEMM operations. - Enhanced `gemm_py.cc` to handle block realization more robustly, ensuring correct assignment of global symbols and block attributes. - Refactored `inject_pipeline.cc` to streamline buffer read/write region handling, improving clarity and maintainability. - Cleaned up test cases in `test_tilelang_tilelibrary_gemm.py` by removing unnecessary print statements and adjusting function calls for better test execution flow. These changes enhance the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework. --- src/layout/gemm_layouts.cc | 4 ++-- src/op/gemm_py.cc | 21 +++++++++++++------ src/transform/inject_pipeline.cc | 11 +++++----- .../test_tilelang_tilelibrary_gemm.py | 17 +++++++-------- tilelang/engine/phase.py | 4 ---- 5 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 8daafc0da..83f9735d9 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -205,14 +205,14 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n, ICHECK(block_k % 16 == 0); if (transposed) { auto base_layout = makeGemmFragment8x8()->Repeat({1, 2}, false, false); - auto warp_layout = base_layout->Repeat({block_n / warp_n, 1}, true, false)->Replicate(block_m / warp_m); + auto warp_layout = base_layout->Replicate(block_m / warp_m)->Repeat({block_n / warp_n, 1}, true, false); auto block_layout = warp_layout->Repeat({warp_n / 8, block_k / 16}, false, false); return block_layout; } else { auto base_layout = makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false); - auto warp_layout = base_layout->Repeat({1, block_n / warp_n}, true)->Replicate(block_m / warp_m); + auto warp_layout = base_layout->Replicate(block_m / warp_m)->Repeat({1, block_n / warp_n}, true); auto block_layout = warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true); return block_layout; diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index 039287460..d38eb6e56 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -227,16 +227,25 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { auto prim_func = Downcast( (*f)(GetRef(this), T.target, T.thread_bounds, T.thread_var)); - BlockRealize block_realize = Downcast(prim_func->body); ICHECK(prim_func->attrs.defined()); auto global_symbol = prim_func->attrs.GetAttr("global_symbol"); ICHECK(global_symbol.defined()); - auto block = block_realize->block; - { - BlockNode* n = block.CopyOnWrite(); - n->name_hint = global_symbol.value(); + if (prim_func->body.as()) { + BlockRealize block_realize = Downcast(prim_func->body); + auto block = block_realize->block; + { + BlockNode* n = block.CopyOnWrite(); + n->name_hint = global_symbol.value(); + } + return BlockRealize(block_realize->iter_values, block_realize->predicate, block); } - return BlockRealize(block_realize->iter_values, block_realize->predicate, block); + // warp with block realize node + return BlockRealize( + /*iter_values=*/Array(), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/global_symbol.value(), prim_func->body)); } else { LOG(FATAL) << "No lower function found for gemm_py"; } diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index 2806b2533..d7311cf82 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -248,7 +248,6 @@ class PipelineRewriter : public StmtExprMutator { buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions)); } } - ordered_stmts_.resize(pipeline_info_.size()); for (const auto &[block, anno] : pipeline_info_) { ordered_stmts_.Set(anno.order, block); @@ -676,11 +675,6 @@ class PipelineRewriter : public StmtExprMutator { new_block = Downcast(Substitute( new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); - Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); - BlockNode* n = new_block.CopyOnWrite(); - n->reads = access[0]; - n->writes = access[1]; - if (pipeline_info_[block].async) { auto &local_state = async_states_local[stage]; local_state.producer_head = normalized_access_index; @@ -957,6 +951,11 @@ class PipelineInjector : private StmtExprMutator { Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + BlockNode* n = block.CopyOnWrite(); + n->reads = access[0]; + n->writes = access[1]; + for (const auto &buffer : op->alloc_buffers) { buffer_data_to_buffer_.erase(buffer->data); } diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index b095e3414..0088d00d8 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -90,7 +90,6 @@ def run_gemm_ss( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) - print(kernel.get_kernel_source()) profiler = kernel.get_profiler() def ref_program(A, B): @@ -209,7 +208,6 @@ def run_gemm_rs( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) - print(kernel.get_kernel_source()) profiler = kernel.get_profiler() def ref_program(A, B): @@ -280,10 +278,7 @@ def main( else: T.copy(B[k * block_K, bx * block_N], B_shared) T.copy(B_shared, B_frag) - # for i, j in T.Parallel(block_N, block_K): - # B_frag[i, j] = B_shared[j, i] - # T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B) - T.gemm(A_shared, B_frag, C_local, trans_A, trans_B) + T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B) T.copy(C_local, C[by * block_M, bx * block_N]) return main @@ -327,7 +322,6 @@ def run_gemm_sr( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) - print(kernel.get_kernel_source()) profiler = kernel.get_profiler() def ref_program(A, B): @@ -448,7 +442,6 @@ def run_gemm_rr( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) - print(kernel.get_kernel_source()) profiler = kernel.get_profiler() def ref_program(A, B): @@ -478,9 +471,13 @@ def test_gemm_rr(): # tilelang.testing.main() tilelang.disable_cache() # test_gemm_ss() - run_gemm_sr(128, 128, 128, False, False, "float16", "float16", "float16", 128, 128, 32, 2) + # test_gemm_sr() + # test_gemm_rs() + # test_gemm_rr() + + # run_gemm_sr(128, 128, 128, False, False, "float16", "float16", "float16", 128, 128, 32, 2) # tilelang.testing.set_random_seed(42) - # run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1) + run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1) # print("gemm fp16 nt ss done") # exit() diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index ef100ee9c..b8ac49a9a 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -140,11 +140,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.IfStmtBinding()(mod) mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) mod = tilelang.transform.PipelinePlanning()(mod) - print("after pipeline planning") - print(mod) mod = tilelang.transform.InjectSoftwarePipeline()(mod) - print("after inject software pipeline") - print(mod) mod = tilelang.transform.MergeIfStmt()(mod) if allow_fence_proxy(target=target): # in hopper device, wgmma is an async proxy From a3f256477b4fcbed6be0d825d13b0de745556725 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 10 Sep 2025 22:14:16 +0800 Subject: [PATCH 06/10] Refactor GEMM layout and testing for improved clarity and functionality - Updated `gemm_layouts.cc` to enhance layout generation logic for transposed and non-transposed GEMM operations. - Improved block realization handling in `gemm_py.cc` for better assignment of global symbols. - Streamlined buffer read/write region handling in `inject_pipeline.cc` for clarity. - Enhanced test cases in `test_tilelang_tilelibrary_gemm.py` by adjusting function calls and adding new GEMM operation combinations. These changes improve the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework. --- src/layout/gemm_layouts.cc | 6 +- src/op/copy.cc | 1 - src/op/gemm.cc | 9 +- src/op/gemm_py.cc | 15 +- src/target/codegen_cuda.cc | 8 +- src/transform/inject_pipeline.cc | 5 +- .../test_tilelang_tilelibrary_gemm.py | 107 +++++---- tilelang/intrinsics/mma_layout.py | 38 +++- tilelang/intrinsics/mma_macro_generator.py | 199 +++++++++------- tilelang/intrinsics/utils.py | 4 +- tilelang/profiler/__init__.py | 12 +- tilelang/tileop/gemm/__init__.py | 213 +----------------- tilelang/tileop/gemm/gemm_base.py | 131 +++++++++++ tilelang/tileop/gemm/gemm_mma.py | 207 +++++++++++++++++ tilelang/utils/tensor.py | 6 +- 15 files changed, 599 insertions(+), 362 deletions(-) create mode 100644 tilelang/tileop/gemm/gemm_base.py create mode 100644 tilelang/tileop/gemm/gemm_mma.py diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 83f9735d9..567bc644b 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -205,14 +205,16 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n, ICHECK(block_k % 16 == 0); if (transposed) { auto base_layout = makeGemmFragment8x8()->Repeat({1, 2}, false, false); - auto warp_layout = base_layout->Replicate(block_m / warp_m)->Repeat({block_n / warp_n, 1}, true, false); + auto warp_layout = base_layout->Replicate(block_m / warp_m) + ->Repeat({block_n / warp_n, 1}, true, false); auto block_layout = warp_layout->Repeat({warp_n / 8, block_k / 16}, false, false); return block_layout; } else { auto base_layout = makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false); - auto warp_layout = base_layout->Replicate(block_m / warp_m)->Repeat({1, block_n / warp_n}, true); + auto warp_layout = base_layout->Replicate(block_m / warp_m) + ->Repeat({1, block_n / warp_n}, true); auto block_layout = warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true); return block_layout; diff --git a/src/op/copy.cc b/src/op/copy.cc index fc9dd0349..6797d48de 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -402,7 +402,6 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, PassContext pass_ctx = PassContext::Current(); bool disable_tma_lower = pass_ctx->GetConfig(kDisableTMALower, false).value(); - auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, T.layout_map, T.analyzer, T.buffer_oob); if (copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) { diff --git a/src/op/gemm.cc b/src/op/gemm.cc index cae67c936..94abc12d3 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -244,7 +244,6 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size, int best_m = 1; int best_n = 1; float best_balance = std::numeric_limits::max(); - // Try all possible combinations that satisfy the constraints for (int m = 1; m <= max_m_warps && m <= num_warps; m++) { int n = num_warps / m; @@ -252,6 +251,13 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size, // Calculate how balanced this partition is float m_per_warp = static_cast(M) / (m * kMPerWarp); float n_per_warp = static_cast(N) / (n * kNPerWarp); + // m_per_warp and n_per_warp must be greater than 1 + if (m_per_warp < 1 || n_per_warp < 1) + continue; + // m * n must equal num_warps + if (m * n != num_warps) + continue; + float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio); if (balance < best_balance) { @@ -266,7 +272,6 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size, } else { ICHECK(0) << "Unknown GemmWarpPolicy"; } - // Store the computed values in the object's member variables this->m_warp = m_warp; this->n_warp = n_warp; diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index d38eb6e56..4d1c31513 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -234,18 +234,19 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { BlockRealize block_realize = Downcast(prim_func->body); auto block = block_realize->block; { - BlockNode* n = block.CopyOnWrite(); + BlockNode *n = block.CopyOnWrite(); n->name_hint = global_symbol.value(); } - return BlockRealize(block_realize->iter_values, block_realize->predicate, block); + return BlockRealize(block_realize->iter_values, block_realize->predicate, + block); } // warp with block realize node return BlockRealize( - /*iter_values=*/Array(), - /*predicate=*/const_true(), - /*block=*/ - Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, - /*name_hint=*/global_symbol.value(), prim_func->body)); + /*iter_values=*/Array(), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/global_symbol.value(), prim_func->body)); } else { LOG(FATAL) << "No lower function found for gemm_py"; } diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 7e7d1456a..789c0f6c4 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1331,16 +1331,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { os << "}\n"; } else { std::string smem_elem_offset = this->PrintExpr(op->args[6]); - // need_cast_smem_ptr_to_int_ = true; - // this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, - // local_elem_offset, smem_ptr, - // smem_elem_offset); std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); if (trans == 1) func_name += "_trans"; - // this->stream << func_name << "(" << local_ptr "" << ", " << smem_ptr << ");\n"; this->PrintIndent(); - this->stream << func_name << "(" << smem_ptr << " + " << smem_elem_offset<< ", " << local_ptr << " + " << local_elem_offset << ");\n"; + this->stream << func_name << "(" << smem_ptr << " + " << smem_elem_offset + << ", " << local_ptr << " + " << local_elem_offset << ");\n"; } } else if (op->op.same_as(builtin::mma_store())) { int m = Downcast(op->args[0])->value; diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index d7311cf82..162fb8c96 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -951,8 +951,9 @@ class PipelineInjector : private StmtExprMutator { Block block = Downcast(StmtExprMutator::VisitStmt_(op)); - Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); - BlockNode* n = block.CopyOnWrite(); + Array> access = + GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + BlockNode *n = block.CopyOnWrite(); n->reads = access[0]; n->writes = access[1]; diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index 0088d00d8..332008528 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -1,4 +1,3 @@ -from asyncio import threads from tilelang import tvm as tvm import tilelang.testing @@ -90,7 +89,9 @@ def run_gemm_ss( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) - profiler = kernel.get_profiler() + + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): import torch @@ -109,11 +110,21 @@ def ref_program(A, B): def test_gemm_ss(): # More test case can be found in kernel/test_tilelang_kernel_gemm.py # GEMM tests for float16 - run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0) - run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0) - run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0) - run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0) - + run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2) + run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2) + run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2) + run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2) + # n8 test + run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128) + + # int8 test + run_gemm_ss(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2) + + # float8 tests + run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) def matmul_rs( @@ -208,7 +219,7 @@ def run_gemm_rs( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) - profiler = kernel.get_profiler() + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): import torch @@ -226,8 +237,22 @@ def ref_program(A, B): def test_gemm_rs(): # GEMM tests for float16 - run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 0) - run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 0) + run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2) + run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2) + + # n8 tests + run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128) + + # int8 tests + run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2) + + # float8 tests + run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) def matmul_sr( @@ -322,7 +347,7 @@ def run_gemm_sr( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) - profiler = kernel.get_profiler() + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): import torch @@ -345,6 +370,18 @@ def test_gemm_sr(): run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2) run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2) + # n8 tests + run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128) + + # int8 tests + run_gemm_sr(128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_sr(128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_sr(128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_sr(128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2) + + # float8 tests + run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) + def matmul_rr( M, @@ -442,7 +479,7 @@ def run_gemm_rr( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) - profiler = kernel.get_profiler() + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): import torch @@ -465,40 +502,20 @@ def test_gemm_rr(): run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2) run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2) run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2) + # n8 tests + run_gemm_rr(128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2) + run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2) + + # int8 tests + run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2) + run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2) + + # float8 tests + run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) if __name__ == "__main__": # tilelang.testing.main() - tilelang.disable_cache() - # test_gemm_ss() - # test_gemm_sr() - # test_gemm_rs() - # test_gemm_rr() - - # run_gemm_sr(128, 128, 128, False, False, "float16", "float16", "float16", 128, 128, 32, 2) - # tilelang.testing.set_random_seed(42) - run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1) - # print("gemm fp16 nt ss done") - # exit() - - # run_gemm_rs(128, 128, 32, False, True, "float16", "float16", "float16", 128, 128, 32, 0) - # print("gemm fp16 nt rs done") - # run_gemm_rs(128, 128, 32, False, False, "float16", "float16", "float16", 128, 128, 32, 0) - # print("gemm fp16 nn rs done") - # run_gemm_rs(128, 128, 32, True, False, "float16", "float16", "float16", 128, 128, 32, 0) - # print("gemm fp16 tn rs done") - # run_gemm_rs(128, 128, 32, True, True, "float16", "float16", "float16", 128, 128, 32, 0) - # print("gemm fp16 tt rs done") - - # run_gemm_rs(16, 16, 16, True, False, "float16", "float16", "float16", 16, 16, 16, 0, 32) - - # run_gemm_rr(128, 128, 32, False, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0) - # print("gemm bf16 nn rr done") - # run_gemm_rr(128, 128, 32, False, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0) - # print("gemm bf16 nt rr done") - # run_gemm_rr(128, 128, 32, True, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0) - # print("gemm bf16 tn rr done") - # run_gemm_rr(128, 128, 32, True, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0) - # print("gemm bf16 tt rr done") - - + run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index e66194bbd..b1f5155ed 100644 --- a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -52,31 +52,49 @@ def shared_16x16_to_mma_a_32x8_layout(i, j): thread_id = 4 * (i % 8) + (j % 8) // 2 return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2) + def shared_16x16_to_mma_a_32x8_layout_trans(i, j): return shared_16x16_to_mma_a_32x8_layout(j, i) + # mma.sync matrix B layout, if wanna trans, please apply map_indices def shared_16x16_to_mma_b_32x8_layout(i, j): thread_id = 4 * (i % 8) + (j % 8) // 2 return thread_id, 4 * (i // 8) + (j // 8) * 2 + (j % 2) + def shared_16x16_to_mma_b_32x8_layout_trans(i, j): return shared_16x16_to_mma_b_32x8_layout(j, i) + shared_16x16_to_mma_32x8_layout_sr_a = shared_16x16_to_mma_a_32x8_layout shared_16x16_to_mma_32x8_layout_sr_b = shared_16x16_to_mma_b_32x8_layout shared_16x16_to_mma_32x8_layout_rs_a = shared_16x16_to_mma_a_32x8_layout_trans shared_16x16_to_mma_32x8_layout_rs_b = shared_16x16_to_mma_b_32x8_layout_trans -def shared_16x32_to_mma_32x16_layout(i, j): +def shared_16x32_to_mma_a_32x16_layout(i, j): thread_id = 4 * (i % 8) + (j % 16) // 4 return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4 -def shared_32x16_to_mma_32x16_layout(i, j): - thread_id = (i % 16) // 4 + 4 * (j % 8) - return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4 +def shared_32x16_to_mma_a_32x16_layout_trans(i, j): + return shared_16x32_to_mma_a_32x16_layout(j, i) + + +def shared_16x32_to_mma_b_32x16_layout(i, j): + thread_id = 4 * (i % 8) + (j % 16) // 4 + return thread_id, 8 * (i // 8) + (j // 16) * 4 + j % 4 + + +def shared_32x16_to_mma_b_32x16_layout_trans(i, j): + return shared_16x32_to_mma_b_32x16_layout(j, i) + + +shared_16x32_to_mma_32x16_layout_sr_a = shared_16x32_to_mma_a_32x16_layout +shared_16x32_to_mma_32x16_layout_sr_b = shared_16x32_to_mma_b_32x16_layout +shared_16x32_to_mma_32x16_layout_rs_a = shared_32x16_to_mma_a_32x16_layout_trans +shared_16x32_to_mma_32x16_layout_rs_b = shared_32x16_to_mma_b_32x16_layout_trans def mma_32x8_to_shared_16x16_layout(thread_id, local_id): @@ -85,6 +103,18 @@ def mma_32x8_to_shared_16x16_layout(thread_id, local_id): return row, col +def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id): + row = 8 * (local_id % 8 // 4) + (thread_id // 4) + col = 16 * (local_id // 8) + (thread_id % 4) * 4 + (local_id % 4) + return row, col + + +def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id): + row = 8 * (local_id // 8) + (thread_id // 4) + col = 16 * (local_id % 8 // 4) + (thread_id % 4) * 4 + (local_id % 4) + return row, col + + def shared_16x16_to_mma_32x8_smoothlayout(i, j): return (i * 2 + j // 8, j % 8) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 6787a930f..c5a09a906 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -9,18 +9,27 @@ get_ldmatrix_offset, ) from tilelang.utils import is_fragment +from tilelang.intrinsics.mma_layout import ( + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_b, + shared_16x32_to_mma_32x16_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_b, + mma_load_a_32x16_to_shared_16x32_layout, + mma_load_b_32x16_to_shared_16x32_layout, +) lift = convert -# TODO(lei): Add Typing for this file class TensorCoreIntrinEmitter(object): """ To eliminate Python syntax within TIR Macro. """ M_DIM = 16 - N_DIM = 16 + # use lowercase as n_dim can be dynamic + # the smallest instructions can be m16n8k16, so the n_dim can also be 8 + n_dim = 16 WARP_SIZE = 32 dtype_abbrv = { "float16": "fp16", @@ -65,13 +74,11 @@ def __init__( self.chunk = chunk self._initialize_k_dim(a_dtype) self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) - self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) + self._initialize_micro_size(self.M_DIM, self.k_dim) + self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE) self._initialize_mma_prefix(self.k_dim) - self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) self._initialize_is_m_first(is_m_first) - self.warp_rows = warp_row_tiles // self.micro_size_x - self.warp_cols = warp_col_tiles // self.micro_size_y self.reduce_k = reduce_k self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k self.num_elems_per_byte = num_elems_per_byte @@ -105,9 +112,27 @@ def _initialize_mma_prefix(self, k_dim: int = 16): else: raise ValueError("Unsupported k_dim") - def _initialize_micro_size(self, m_dim: int = 16, n_dim: int = 16, k_dim: int = 16): + def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): + warp_row_tiles = self.warp_row_tiles + warp_col_tiles = self.warp_col_tiles + assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}" + assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}" + assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}" + assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}" + + self.warp_rows = warp_row_tiles // m_dim + + if warp_col_tiles % 16 == 0: + self.n_dim = 16 + self.micro_size_y = 16 + self.warp_cols = warp_col_tiles // 16 + else: + # must be divisible by 8 + self.n_dim = 8 + self.micro_size_y = 8 + self.warp_cols = warp_col_tiles // 8 + self.micro_size_x = m_dim - self.micro_size_y = n_dim self.micro_size_k = k_dim def _initialize_is_m_first(self, is_m_first: Optional[bool] = False): @@ -175,6 +200,8 @@ def ldmatrix_a(self, local_size_a = self.local_size_a a_dtype = self.a_dtype a_transposed = self.a_transposed + # ldmatrix cannot be used for int8 + trans case. + ldmatrix_available = not (DataType(a_dtype).bits == 8 and a_transposed) thread_binding = self.get_thread_binding() @@ -189,22 +216,27 @@ def _warp_ldmatrix_a( stride = A_shared_buf.shape[-1] tx, _, warp_m = self.extract_thread_binding(thread_binding) trans = self.a_transposed - + for i in T.serial(warp_rows): # Assign A_shared_buf_elem wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk] - T.ptx_ldmatrix( - a_dtype, - T.bool(trans), - 4, - ".b16", - A_local_buf.data, - i * local_size_a, - T.address_of(A_shared_buf_elem), - get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), - ) + if ldmatrix_available: + T.ptx_ldmatrix( + a_dtype, + T.bool(trans), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_buf_elem), + get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + ) + else: + for j in T.serial(local_size_a): + mi, mk = mma_load_a_32x16_to_shared_16x32_layout(tx, j) + A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi] return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) @@ -222,6 +254,9 @@ def ldmatrix_b(self, b_dtype = self.b_dtype b_transposed = self.b_transposed thread_binding = self.get_thread_binding() + replicate_b = (self.n_dim == 16) + # ldmatrix cannot be used for int8 + trans case. + ldmatrix_available = not (DataType(b_dtype).bits == 8 and not b_transposed) @T.macro def _warp_ldmatrix_b( @@ -235,24 +270,34 @@ def _warp_ldmatrix_b( tx, warp_n, _ = self.extract_thread_binding(thread_binding) trans = not b_transposed - for j in T.serial(warp_cols): + for i in T.serial(warp_cols): # Assign B_shared_elem wi, wk = ( - warp_n * warp_col_tiles + j * micro_size_y, + warp_n * warp_col_tiles + i * micro_size_y, rk * chunk + ki * micro_size_k, ) - B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi] - T.ptx_ldmatrix( - b_dtype, - T.bool(trans), - 4, - ".b16", - B_local_buf.data, - j * local_size_b, - T.address_of(B_shared_buf_elem), - get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), - ) + if ldmatrix_available: + B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, + wi] + + T.ptx_ldmatrix( + b_dtype, + T.bool(trans), + 4 if replicate_b else 2, + ".b16", + B_local_buf.data, + i * local_size_b, + T.address_of(B_shared_buf_elem), + get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), + ) + + else: + # load 16x32 data from shared buffer to local buffer + # must be transposed. + for j in T.serial(local_size_b): + mi, mk = mma_load_b_32x16_to_shared_16x32_layout(tx, j) + B_local_buf[i * local_size_b + j] = B_shared_buf[wk + mk, wi + mi] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) @@ -271,6 +316,7 @@ def mma(self, accum_dtype = self.accum_dtype accum_dtype_abbrv = self.accum_dtype_abbrv mma_prefix = self.mma_prefix + replicate_b = (self.n_dim == 16) a_is_fragment = is_fragment(A_local_buf) b_is_fragment = is_fragment(B_local_buf) @@ -296,23 +342,24 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): i * warp_cols * local_size_out + j * local_size_out, T.bool(False), # saturate ) - - T.ptx_mma( - accum_dtype, - mma_prefix, - "row", - "col", - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_local_buf.data, - a_local_stride + i * local_size_a, - B_local_buf.data, - b_local_stride + j * local_size_b + lift(local_size_b) // 2, - C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, - T.bool(False), # saturate - ) + if replicate_b: + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + a_local_stride + i * local_size_a, + B_local_buf.data, + b_local_stride + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + + lift(local_size_out) // 2, + T.bool(False), # saturate + ) return _warp_mma(A_local_buf, B_local_buf, C_local_buf) @@ -326,7 +373,7 @@ def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): is_global = pid_m is not None and pid_n is not None BLOCK_M = block_row_warps * warp_rows BLOCK_N = block_col_warps * warp_cols - M_DIM, N_DIM = self.M_DIM, self.N_DIM + M_DIM, n_dim = self.M_DIM, self.n_dim C_buf_dims = len(C_buf.shape) assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" @@ -346,7 +393,7 @@ def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding): row, col = T.meta_var(mma_store_index_map(tx, local_id)) if C_buf_dims == 2: C_buf[(warp_m * warp_rows + i) * M_DIM + row, - (warp_n * warp_cols + j) * N_DIM + + (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[i * (warp_cols * local_size_out) + j * local_size_out + local_id] else: @@ -364,7 +411,7 @@ def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): row, col = T.meta_var(mma_store_index_map(tx, local_id)) C_buf[ (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, - (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col, + (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col, ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] @@ -396,15 +443,9 @@ def make_mma_load_layout(self, If `local_buf` is not detected to be a fragment buffer. """ from tilelang.utils import is_fragment - from tilelang.intrinsics.mma_layout import ( - shared_16x16_to_mma_32x8_layout_sr_a, - shared_16x16_to_mma_32x8_layout_sr_b, - shared_16x32_to_mma_32x16_layout, - shared_32x16_to_mma_32x16_layout, - ) assert matrix in ["A", "B"], "matrix should be either A or B" - matrix_is_a : bool = matrix == "A" - matrix_is_b : bool = matrix == "B" + matrix_is_a: bool = matrix == "A" + matrix_is_b: bool = matrix == "B" dtype = self.a_dtype if matrix_is_a else self.b_dtype dtype_bits = DataType(dtype).bits transposed = self.a_transposed if matrix_is_a else self.b_transposed @@ -421,8 +462,8 @@ def make_mma_load_layout(self, transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b elif dtype_bits == 8: - transform_func_sr_a = shared_16x32_to_mma_32x16_layout - transform_func_sr_b = shared_32x16_to_mma_32x16_layout + transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a + transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b else: raise ValueError(f"Unsupported dtype {dtype}") @@ -435,9 +476,11 @@ def make_mma_load_layout(self, # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix_is_a: - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( + j, i) elif matrix_is_b: - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( + j, i) else: raise ValueError(f"Unsupported matrix {matrix}") @@ -471,11 +514,11 @@ def forward_index(i: int, j: int) -> int: return local_id base_fragment = T.Fragment( - [micro_size_r, micro_size_s], + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], forward_thread_fn=forward_thread, forward_index_fn=forward_index, ) - + warp_rows, warp_cols = self.warp_rows, self.warp_cols chunk = self.chunk @@ -486,30 +529,30 @@ def forward_index(i: int, j: int) -> int: if is_sr_axis_order: warp_fragment = base_fragment.repeat([warp_s, warp_r], - repeat_on_thread=False, - lower_dim_first=False) + repeat_on_thread=False, + lower_dim_first=False) if matrix_is_a: block_fragment = warp_fragment.repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + repeat_on_thread=True, + lower_dim_first=True).replicate(replicate) elif matrix_is_b: block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True) + repeat_on_thread=True, + lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") else: warp_fragment = base_fragment.repeat([warp_r, warp_s], - repeat_on_thread=False, - lower_dim_first=True) + repeat_on_thread=False, + lower_dim_first=True) if matrix_is_a: block_fragment = warp_fragment.repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + repeat_on_thread=True, + lower_dim_first=True).replicate(replicate) elif matrix_is_b: block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True) + repeat_on_thread=True, + lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") diff --git a/tilelang/intrinsics/utils.py b/tilelang/intrinsics/utils.py index 13d6c63f2..08730a40a 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/intrinsics/utils.py @@ -4,7 +4,7 @@ ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, ldmatrix_16x32_to_shared_16x32_layout_a, - ldmatrix_16x32_to_shared_16x32_layout_b, + ldmatrix_32x16_to_shared_16x32_layout_b, mma_store_32x8_to_shared_16x16_layout, ) from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m) @@ -37,7 +37,7 @@ def get_ldmatrix_offset( return new_row_idx * stride + new_col_idx elif dtype_bits == 8: if matrix == "B" and transposed: - transform_func = ldmatrix_16x32_to_shared_16x32_layout_b + transform_func = ldmatrix_32x16_to_shared_16x32_layout_b new_row_idx, new_col_idx = transform_func(row_idx, col_idx) return new_row_idx * stride + new_col_idx elif matrix == "A" and not transposed: diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py index d63c4db1f..55391cea1 100644 --- a/tilelang/profiler/__init__.py +++ b/tilelang/profiler/__init__.py @@ -126,9 +126,17 @@ def assert_allclose( if lhs is not None and rhs is not None: # in case of numsplit template, the ref output may be None # which means the value is invalid, so we skip the comparison + def is_float8(tensor: torch.Tensor) -> bool: + return tensor.dtype in { + torch.float8_e5m2, + torch.float8_e5m2fnuz, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + } + torch_assert_close( - lhs, - rhs, + lhs if not is_float8(lhs) else lhs.to(torch.float32), + rhs if not is_float8(rhs) else rhs.to(torch.float32), rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio, diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 30ea56777..1c8ca8652 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -2,17 +2,13 @@ from tvm import tir from tilelang.utils.target import ( target_is_cuda,) -from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) -from tilelang.layout import make_swizzled_layout -from tilelang import language as T from tvm.target import Target from tvm.ir.base import Node from tvm.runtime import Scriptable import tvm.ffi from tilelang.ir import GemmWarpPolicy -from tilelang.transform.simplify import _Simplify -from tilelang.utils.language import is_shared, is_fragment +from .gemm_mma import GemmMMA + @tvm.ffi.register_func("tl.gemm_py.infer_layout") def gemm_py_infer_layout(gemm_py, target, thread_bounds): @@ -56,49 +52,7 @@ class GemmPy(Node, Scriptable): def infer_layout(self, target: Target, thread_nums: int): if target_is_cuda(target): # TODO(lei): Support more cuda architectures, now mma only - # Now only implement ssr layout - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) - warp_row_tiles = int(self.M // m_warp) - warp_col_tiles = int(self.N // n_warp) - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=self.in_dtype, - b_dtype=self.in_dtype, - accum_dtype=self.accum_dtype, - a_transposed=self.trans_A, - b_transposed=self.trans_B, - block_row_warps=m_warp, - block_col_warps=n_warp, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=self.chunk, - ) - if self.is_gemm_ss(): - return { - self.A: make_swizzled_layout(self.A), - self.B: make_swizzled_layout(self.B), - self.C: mma_emitter.make_mma_store_layout(self.C), - } - elif self.is_gemm_sr(): - return { - self.A: make_swizzled_layout(self.A), - self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), - self.C: mma_emitter.make_mma_store_layout(self.C), - } - elif self.is_gemm_rs(): - return { - self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), - self.B: make_swizzled_layout(self.B), - self.C: mma_emitter.make_mma_store_layout(self.C), - } - elif self.is_gemm_rr(): - return { - self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), - self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), - self.C: mma_emitter.make_mma_store_layout(self.C), - } - else: - raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + return GemmMMA(self).infer_layout(target, thread_nums) else: raise ValueError(f"Unsupported target: {target}") @@ -106,165 +60,6 @@ def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): if target_is_cuda(target): # TODO(lei): Support more cuda architectures, now mma only # Now only implement ssr layout - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) - warp_row_tiles = int(self.M // m_warp) - warp_col_tiles = int(self.N // n_warp) - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=self.in_dtype, - b_dtype=self.in_dtype, - accum_dtype=self.accum_dtype, - a_transposed=self.trans_A, - b_transposed=self.trans_B, - block_row_warps=m_warp, - block_col_warps=n_warp, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=self.chunk, - thread_var=thread_var, - ) - - in_dtype = self.in_dtype - warp_rows = mma_emitter.warp_rows - warp_cols = mma_emitter.warp_cols - local_size_a = mma_emitter.local_size_a - local_size_b = mma_emitter.local_size_b - block_K = mma_emitter.chunk - micro_size_k = mma_emitter.micro_size_k - A_shared = self.A - B_shared = self.B - C_local = self.C - - if self.is_gemm_ss(): - @T.prim_func - def _gemm_ssr() -> None: - """ - The inner macro that loads data from shared buffers A_shared and - B_shared into local fragments, then issues Tensor Core mma ops, - accumulating into C_local. - """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - ) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local, ki) - - # Simplify to optimize the index computing - # Must inline let statements to simplify the analysis - return _Simplify(_gemm_ssr, inline_let=True) - elif self.is_gemm_sr(): - B_local = self.B - @T.prim_func - def _gemm_srr() -> None: - """ - The inner macro that loads data from shared buffers A_shared and - B_shared into local fragments, then issues Tensor Core mma ops, - accumulating into C_local. - """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - - for ki in T.serial(0, (block_K // micro_size_k)): - - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - ) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local, ki) - - # Simplify to optimize the index computing - # Must inline let statements to simplify the analysis - # alloc_buffers body - # insert into parrent block - return _Simplify(_gemm_srr, inline_let=True) - elif self.is_gemm_rs(): - A_local = self.A - @T.prim_func - def _gemm_rsr() -> None: - """ - The inner macro that loads data from shared buffers A_shared and - B_shared into local fragments, then issues Tensor Core mma ops, - accumulating into C_local. - """ - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - - for ki in T.serial(0, (block_K // micro_size_k)): - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - ) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local, ki) - - # Simplify to optimize the index computing - # Must inline let statements to simplify the analysis - return _Simplify(_gemm_rsr, inline_let=True) - elif self.is_gemm_rr(): - A_local = self.A - B_local = self.B - @T.prim_func - def _gemm_rsr() -> None: - """ - The inner macro that loads data from shared buffers A_shared and - B_shared into local fragments, then issues Tensor Core mma ops, - accumulating into C_local. - """ - - for ki in T.serial(0, (block_K // micro_size_k)): - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local, ki) - - # Simplify to optimize the index computing - # Must inline let statements to simplify the analysis - return _Simplify(_gemm_rsr, inline_let=True) - else: - raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + return GemmMMA(self).lower(target, thread_nums, thread_var) else: raise ValueError(f"Unsupported target: {target}") - - @property - def in_dtype(self) -> str: - assert self.A.dtype == self.B.dtype, "A and B must have the same dtype" - return self.A.dtype - - @property - def accum_dtype(self) -> str: - return self.C.dtype - - @property - def chunk(self) -> int: - return self.A.shape[-2] if self.trans_A else self.A.shape[-1] - - def is_gemm_ss(self) -> bool: - return is_shared(self.A) and is_shared(self.B) - - def is_gemm_sr(self) -> bool: - return is_shared(self.A) and is_fragment(self.B) - - def is_gemm_rs(self) -> bool: - return is_fragment(self.A) and is_shared(self.B) - - def is_gemm_rr(self) -> bool: - return is_fragment(self.A) and is_fragment(self.B) diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py new file mode 100644 index 000000000..e9e9490d9 --- /dev/null +++ b/tilelang/tileop/gemm/gemm_base.py @@ -0,0 +1,131 @@ +from dataclasses import dataclass +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tvm.tir import Var +from tilelang.utils.language import is_shared, is_fragment +from tilelang.ir import GemmWarpPolicy + +@dataclass +class GemmBase(object): + gemm_node: "GemmPy" + + def infer_layout(self, target: Target, thread_nums: int): + raise NotImplementedError("infer_layout is not implemented") + + def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): + raise NotImplementedError("lower is not implemented") + + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) + + @property + def M(self) -> int: + return self.gemm_node.M + + @property + def N(self) -> int: + return self.gemm_node.N + + @property + def K(self) -> int: + return self.gemm_node.K + + @property + def trans_A(self) -> bool: + return self.gemm_node.trans_A + + @property + def trans_B(self) -> bool: + return self.gemm_node.trans_B + + @property + def in_dtype(self) -> str: + return self.gemm_node.in_dtype + + @property + def accum_dtype(self) -> str: + return self.gemm_node.accum_dtype + + @property + def chunk(self) -> int: + return self.gemm_node.chunk + + @property + def in_dtype(self) -> str: + assert self.A.dtype == self.B.dtype, "A and B must have the same dtype" + return self.A.dtype + + @property + def accum_dtype(self) -> str: + return self.C.dtype + + @property + def chunk(self) -> int: + return self.A.shape[-2] if self.trans_A else self.A.shape[-1] + + @property + def A(self) -> tir.Buffer: + return self.gemm_node.A + + @property + def B(self) -> tir.Buffer: + return self.gemm_node.B + + @property + def C(self) -> tir.Buffer: + return self.gemm_node.C + + @property + def APtr(self) -> tir.PrimExpr: + return self.gemm_node.APtr + + @property + def BPtr(self) -> tir.PrimExpr: + return self.gemm_node.BPtr + + @property + def CPtr(self) -> tir.PrimExpr: + return self.gemm_node.CPtr + + @property + def stride_A(self) -> int: + return self.gemm_node.stride_A + + @property + def stride_B(self) -> int: + return self.gemm_node.stride_B + + @property + def offset_A(self) -> int: + return self.gemm_node.offset_A + + @property + def offset_B(self) -> int: + return self.gemm_node.offset_B + + @property + def clear_accum(self) -> bool: + return self.gemm_node.clear_accum + + @property + def k_pack(self) -> int: + return self.gemm_node.k_pack + + @property + def wg_wait(self) -> int: + return self.gemm_node.wg_wait + + @property + def policy(self) -> GemmWarpPolicy: + return self.gemm_node.policy diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/tileop/gemm/gemm_mma.py new file mode 100644 index 000000000..fa82a2034 --- /dev/null +++ b/tilelang/tileop/gemm/gemm_mma.py @@ -0,0 +1,207 @@ +from .gemm_base import GemmBase +from tilelang.layout import make_swizzled_layout +from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitter,) +from tilelang.utils.language import is_shared, is_fragment +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmMMA(GemmBase): + + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + ) + if self.is_gemm_ss(): + return { + self.A: make_swizzled_layout(self.A), + self.B: make_swizzled_layout(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_sr(): + return { + self.A: make_swizzled_layout(self.A), + self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rs(): + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: make_swizzled_layout(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rr(): + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + ) + + in_dtype = self.in_dtype + warp_rows = mma_emitter.warp_rows + warp_cols = mma_emitter.warp_cols + local_size_a = mma_emitter.local_size_a + local_size_b = mma_emitter.local_size_b + block_K = mma_emitter.chunk + micro_size_k = mma_emitter.micro_size_k + A_shared = self.A + B_shared = self.B + C_local = self.C + + if self.is_gemm_ss(): + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_sr(): + B_local = self.B + @T.prim_func + def _gemm_srr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + # alloc_buffers body + # insert into parrent block + return _Simplify(_gemm_srr, inline_let=True) + elif self.is_gemm_rs(): + A_local = self.A + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + elif self.is_gemm_rr(): + A_local = self.A + B_local = self.B + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + + for ki in T.serial(0, (block_K // micro_size_k)): + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py index bab967a85..07a34cc44 100644 --- a/tilelang/utils/tensor.py +++ b/tilelang/utils/tensor.py @@ -113,9 +113,11 @@ def get_tensor(param: KernelParam) -> torch.Tensor: else: return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype) elif supply_type == TensorSupplyType.Uniform: - return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0) + return torch.empty( + *shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype) elif supply_type == TensorSupplyType.Normal: - return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0) + return torch.empty( + *shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype) elif supply_type == TensorSupplyType.Randn: return torch.randn(*shape, device=device).to(dtype) elif supply_type == TensorSupplyType.Zero: From ded566e51307f8015167af4eef173c7bad8ee60f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 11 Sep 2025 01:15:11 +0800 Subject: [PATCH 07/10] tfloat32 support. --- CMakeLists.txt | 1 + src/target/codegen_cuda.cc | 3 +- src/target/ptx.cc | 864 ++++++++++++++++++ src/target/ptx.h | 164 ++++ .../test_tilelang_tilelibrary_gemm.py | 39 +- tilelang/intrinsics/mma_layout.py | 53 +- tilelang/intrinsics/mma_macro_generator.py | 41 +- tilelang/intrinsics/utils.py | 19 +- tilelang/tileop/gemm/gemm_base.py | 60 +- tilelang/tileop/gemm/gemm_mma.py | 23 +- 10 files changed, 1200 insertions(+), 67 deletions(-) create mode 100644 src/target/ptx.cc create mode 100644 src/target/ptx.h diff --git a/CMakeLists.txt b/CMakeLists.txt index b780ae2e7..a54b6f5ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -132,6 +132,7 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS if(USE_CUDA) tilelang_file_glob(GLOB TILE_LANG_CUDA_SRCS src/runtime/*.cc + src/target/ptx.cc src/target/codegen_cuda.cc src/target/rt_mod_cuda.cc ) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 789c0f6c4..a2f58b67b 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -15,10 +15,11 @@ #include "../op/builtin.h" #include "arith/pattern_match.h" -#include "target/source/ptx.h" +#include "./ptx.h" namespace tvm { namespace codegen { +using namespace tvm::tl::codegen; static std::string GetFP8Type(DataType type) { std::stringstream stream; diff --git a/src/target/ptx.cc b/src/target/ptx.cc new file mode 100644 index 000000000..f872cad0b --- /dev/null +++ b/src/target/ptx.cc @@ -0,0 +1,864 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ptx.cc + */ + +#include "ptx.h" + +#include +#include +#include +#include +#include + +namespace tvm::tl { +namespace codegen { + +// PTX related data structures and functions. +namespace ptx { + +/*! + * \brief PTX data type. + * \note + * PTX fundamental data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types + * PTX matrix data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types + */ +enum class DataType : int { + kInt4 = 0, + kUInt4 = 1, + kInt8 = 2, + kUInt8 = 3, + kInt16 = 4, + kUInt16 = 5, + kInt32 = 6, + kUInt32 = 7, + kInt64 = 8, + kUInt64 = 9, + kFloat8_e4m3 = 10, + kFloat8_e5m2 = 11, + kFloat16 = 12, + kBFloat16 = 13, + kFloat16x2 = 14, + kFloat32 = 15, + kTensorFloat32 = 16, + kFloat64 = 17, + kBit1 = 18, + kBit8 = 19, + kBit16 = 20, + kBit32 = 21, + kBit64 = 22 +}; + +static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", + ".s32", ".u32", ".s64", ".u64", ".e4m3", ".e5m2", + ".f16", ".bf16", ".f16x2", ".f32", ".tf32", ".f64", + ".b1", ".b8", ".b16", ".b32", ".b64"}; +static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 8, 8, + 16, 16, 32, 32, 32, 64, 1, 8, 16, 32, 64}; + +/*! + * \brief Create PTX data type from string. + */ +inline DataType DTypeFromString(const std::string str) { + if (str == "int4" || str == ".s4") { + return DataType::kInt4; + } else if (str == "uint4" || str == ".u4") { + return DataType::kUInt4; + } else if (str == "int8" || str == ".s8") { + return DataType::kInt8; + } else if (str == "uint8" || str == ".u8") { + return DataType::kUInt8; + } else if (str == "int16" || str == ".s16") { + return DataType::kInt16; + } else if (str == "uint16" || str == ".u16") { + return DataType::kUInt16; + } else if (str == "int32" || str == ".s32") { + return DataType::kInt32; + } else if (str == "uint32" || str == ".u32") { + return DataType::kUInt32; + } else if (str == "int64" || str == ".s64") { + return DataType::kInt64; + } else if (str == "uint64" || str == ".u64") { + return DataType::kUInt64; + } else if (str == "e4m3" || str == ".e4m3") { + return DataType::kFloat8_e4m3; + } else if (str == "e5m2" || str == ".e5m2") { + return DataType::kFloat8_e5m2; + } else if (str == "float16" || str == "fp16" || str == ".f16") { + return DataType::kFloat16; + } else if (str == "bfloat16" || str == "bf16") { + return DataType::kBFloat16; + } else if (str == ".f16x2") { + return DataType::kFloat16x2; + } else if (str == "float32" || str == "fp32" || str == ".f32") { + return DataType::kFloat32; + } else if (str == "tf32") { + return DataType::kTensorFloat32; + } else if (str == "float64" || str == "fp64" || str == ".f64") { + return DataType::kFloat64; + } else if (str == "int1" || str == ".b1") { + return DataType::kBit1; + } else if (str == ".b8") { + return DataType::kBit8; + } else if (str == ".b16") { + return DataType::kBit16; + } else if (str == ".b32") { + return DataType::kBit32; + } else if (str == ".b64") { + return DataType::kBit64; + } else { + LOG(FATAL) << "Unrecognized PTX data type " << str; + } +} + +/*! + * \brief Get the string representation of given PTX data type. + */ +inline std::string DTypeToString(DataType dtype) { return dtype_str[static_cast(dtype)]; } + +/*! + * \brief Get the number of bits of given PTX data type. + */ +inline uint32_t DTypeBits(DataType dtype) { return num_bits[static_cast(dtype)]; } + +/*! + * \brief Extract the value m, n, k from string m*n*k* + */ +inline std::tuple ParseMMAShape(const std::string& str) { + size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k"); + CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) + << "Cannot parse MMA shape " << str; + int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)), + n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1)); + return std::make_tuple(m, n, k); +} + +/*! + * \brief Layout Type + */ +enum class LayoutType : int { kRowMajor = 0, kColumnMajor = 1 }; + +/*! + * \brief Parse layout type + */ +LayoutType LayoutTypeFromString(const std::string& str) { + if (str == "row") { + return LayoutType::kRowMajor; + } else if (str == "col") { + return LayoutType::kColumnMajor; + } else { + LOG(FATAL) << "Unrecognized layout type " << str; + } +} + +static const char* layout_type_str[] = {"row", "col"}; + +/*! + * \brief Convert layout type to string. + */ +inline std::string LayoutTypeToString(LayoutType layout) { + return layout_type_str[static_cast(layout)]; +} + +/*! + * \brief MMA Configurations, used to determine validity. + */ +struct MMAConfig { + explicit MMAConfig(int m, int n, int k, DataType dtype_mul, bool use_bit_op, bool sparse) + : m(m), n(n), k(k), dtype_mul(dtype_mul), use_bit_op(use_bit_op), sparse(sparse) {} + int m, n, k; + DataType dtype_mul; + bool use_bit_op; + bool sparse; + inline bool operator==(const MMAConfig& other) { + return m == other.m && n == other.n && k == other.k && dtype_mul == other.dtype_mul && + use_bit_op == other.use_bit_op && sparse == other.sparse; + } +}; + +/*! + * \brief Valid MMA configurations + * \note Reference: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape + */ +const MMAConfig valid_mma_configs[] = { + MMAConfig(8, 8, 4, DataType::kFloat64, false, false), + MMAConfig(8, 8, 4, DataType::kFloat16, false, false), + MMAConfig(16, 8, 8, DataType::kFloat16, false, false), + MMAConfig(16, 8, 16, DataType::kFloat16, false, false), + MMAConfig(16, 8, 8, DataType::kBFloat16, false, false), + MMAConfig(16, 8, 16, DataType::kBFloat16, false, false), + MMAConfig(16, 8, 4, DataType::kFloat32, false, false), + MMAConfig(16, 8, 8, DataType::kFloat32, false, false), + MMAConfig(16, 8, 4, DataType::kTensorFloat32, false, false), + MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, false), + MMAConfig(8, 8, 16, DataType::kInt8, false, false), + MMAConfig(16, 8, 16, DataType::kInt8, false, false), + MMAConfig(16, 8, 32, DataType::kInt8, false, false), + MMAConfig(8, 8, 16, DataType::kUInt8, false, false), + MMAConfig(16, 8, 16, DataType::kUInt8, false, false), + MMAConfig(16, 8, 32, DataType::kUInt8, false, false), + MMAConfig(8, 8, 32, DataType::kInt4, false, false), + MMAConfig(16, 8, 32, DataType::kInt4, false, false), + MMAConfig(16, 8, 64, DataType::kInt4, false, false), + MMAConfig(8, 8, 32, DataType::kUInt4, false, false), + MMAConfig(16, 8, 32, DataType::kUInt4, false, false), + MMAConfig(16, 8, 64, DataType::kUInt4, false, false), + MMAConfig(8, 8, 128, DataType::kBit1, true, false), + MMAConfig(16, 8, 128, DataType::kBit1, true, false), + MMAConfig(16, 8, 256, DataType::kBit1, true, false), + MMAConfig(16, 8, 16, DataType::kFloat16, false, true), + MMAConfig(16, 8, 32, DataType::kFloat16, false, true), + MMAConfig(16, 8, 16, DataType::kBFloat16, false, true), + MMAConfig(16, 8, 32, DataType::kBFloat16, false, true), + MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, true), + MMAConfig(16, 8, 16, DataType::kTensorFloat32, false, true), + MMAConfig(16, 8, 32, DataType::kInt8, false, true), + MMAConfig(16, 8, 64, DataType::kInt8, false, true), + MMAConfig(16, 8, 32, DataType::kUInt8, false, true), + MMAConfig(16, 8, 64, DataType::kUInt8, false, true), + MMAConfig(16, 8, 64, DataType::kInt4, false, true), + MMAConfig(16, 8, 128, DataType::kInt4, false, true), + MMAConfig(16, 8, 64, DataType::kUInt4, false, true), + MMAConfig(16, 8, 128, DataType::kUInt4, false, true), + MMAConfig(16, 8, 32, DataType::kFloat8_e4m3, false, false), + MMAConfig(16, 8, 64, DataType::kFloat8_e4m3, false, true), + MMAConfig(16, 8, 32, DataType::kFloat8_e5m2, false, false), + MMAConfig(16, 8, 64, DataType::kFloat8_e5m2, false, true), +}; + +/*! + * \brief Check whether the multiplicand data type and accumulator data type is valid for MMA + * computation. + * \param dtype_a The data type of multiplicand a. + * \param dtype_b The data type of multiplicand b. + * \param dtype_c The data type of accumulator c. + * \note Reference: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types + */ +void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) { + std::string ab_not_match_err_str = "The multiplicands' data type " + DTypeToString(dtype_a) + + DTypeToString(dtype_b) + " do not match."; + // check a and b + switch (dtype_a) { + case DataType::kBit1: + case DataType::kFloat16: + case DataType::kBFloat16: + case DataType::kFloat32: + case DataType::kTensorFloat32: + case DataType::kFloat64: + CHECK(dtype_a == dtype_b) << ab_not_match_err_str; + break; + case DataType::kInt4: + case DataType::kUInt4: + CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) << ab_not_match_err_str; + break; + case DataType::kInt8: + case DataType::kUInt8: + CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << ab_not_match_err_str; + break; + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: + CHECK(dtype_b == DataType::kFloat8_e4m3 || dtype_b == DataType::kFloat8_e5m2) + << ab_not_match_err_str; + break; + default: + CHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a) + << DTypeToString(dtype_b); + } + // check a,b and c + switch (dtype_a) { + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + CHECK(dtype_c == DataType::kInt32) + << "For multiplicand data type " << DTypeToString(dtype_a) << DTypeToString(dtype_b) + << ", accumulator data type should be s32."; + break; + case DataType::kFloat16: + CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32) + << "For multiplicand data type f16, accumulator data type should be f16/f32."; + break; + case DataType::kBFloat16: + case DataType::kFloat32: + case DataType::kTensorFloat32: + CHECK(dtype_c == DataType::kFloat32) + << "For multiplicand data type bf16/tf32, accumulator data type can only be f32."; + break; + case DataType::kFloat64: + CHECK(dtype_c == DataType::kFloat64) + << "For multiplicand data type f64, accumulator data type can only be f64."; + break; + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: + CHECK(dtype_c == DataType::kFloat32) + << "For multiplicand data type e4m3/e5m2, accumulator data type can only be f32."; + break; + default: + CHECK(false) << "Invalid multiplicand/accumulator data types: " << DTypeToString(dtype_a) + << DTypeToString(dtype_b) << DTypeToString(dtype_c) << "."; + } +} + +/*! + * \brief Check whether the given configuration is valid for MMA computation. + * \param m The M in mMnNkK of MMA instructions. + * \param n The N in mMnNkK of MMA instructions. + * \param k The K in mMnNkK of MMA instructions. + * \param layout_a The layout of multiplicand A (row/col). + * \param layout_b The layout of multiplicand B (row/col). + * \param dtype_a The data type of multiplicand A. + * \param dtype_b The data type of multiplicand B. + * \param dtype_c The data type of accumulator C. + * \param bit_op The bit operator for 1-bit MMA computation, can be "xor"/"and" or ""(if it's not + * 1-bit MMA). + * \param sparse Whether it's Sparse MMA or not. + * \param saturate Whether saturate output or not. + */ +void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType layout_b, + DataType dtype_a, DataType dtype_b, DataType dtype_c, + const std::string& bit_op, bool sparse, bool saturate) { + CHECK(bit_op == "xor" || bit_op == "and" || bit_op == "") + << "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and."; + bool use_bit_op = !bit_op.empty(); + if (use_bit_op) { + CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1-bit multiplicand."; + } + CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c); + if (saturate) { + CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a == DataType::kInt8 || + dtype_a == DataType::kUInt8) + << "Output saturation only applicable to multiplicand type s4/u4/s8/u8."; + } + + if (!(m == 8 && n == 8 && k == 4 && dtype_a == ptx::DataType::kFloat16)) { + // Only MMA on m8n8k4 for fp16 supports customized layouts. + CHECK(layout_a == LayoutType::kRowMajor && layout_b == LayoutType::kColumnMajor) + << "Invalid layout combination " << LayoutTypeToString(layout_a) << "," + << LayoutTypeToString(layout_b) << "."; + } + + MMAConfig config(m, n, k, dtype_a, use_bit_op, sparse); + bool match = false; + for (const MMAConfig& valid_config : valid_mma_configs) { + if (config == valid_config) { + match = true; + break; + } + } + CHECK(match) << "Cannot find matched MMA configurations."; +} + +/*! + * \brief Fragment attributes + */ +class FragAttrs { + public: + explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type) + : reg_type(reg_type), size(size), ptr_type(ptr_type) {} + /*! \brief PTX register type */ + char reg_type; + /*! \brief Fragment size */ + uint32_t size; + /*! \brief Fragment pointer type */ + std::string ptr_type; +}; + +/*! + * \brief Fragment attributes of given data type. + */ +inline FragAttrs GetFragAttrs(DataType dtype) { + switch (dtype) { + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: + case DataType::kBit16: + case DataType::kFloat16: // .f16x2 register + case DataType::kBFloat16: + case DataType::kTensorFloat32: + return FragAttrs('r', 32, "(unsigned *)"); + case DataType::kInt32: + return FragAttrs('r', 32, "(int *)"); + case DataType::kFloat32: + return FragAttrs('f', 32, "(float *)"); + case DataType::kFloat64: + return FragAttrs('d', 64, "(double *)"); + default: + ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in MMA."; + return FragAttrs('\0', 0, ""); + } +} + +}; // namespace ptx + +/*! + * \brief Replace patterns with replacement strings. + * \note should use std::format instead when codebase is ported to C++20. + */ +class Replacer { + public: + void register_rule(const std::string& pattern, const std::string& replacement) { + _rules.emplace_back(pattern, replacement); + } + std::string rewrite(std::string str) { + for (auto&& rule : _rules) { + auto [pattern, replacement] = rule; + size_t len = pattern.size(); + size_t new_len = replacement.size(); + size_t pos = str.find(pattern); + while (pos != std::string::npos) { + str = str.replace(pos, len, replacement); + pos = str.find(pattern, pos + new_len); + } + } + return str; + } + void empty_rules() { _rules.clear(); } + + private: + std::vector> _rules; +}; + +/*! + * \brief Get the number of MMA computations for given shape and datatype. + */ +inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype) { + if (m == 8 && n == 8 && k == 4 && dtype == ptx::DataType::kFloat16) { + // MMA for m8n8k4 on fp16 would launch 4 MMA computations instead of one. + return 4; + } else { + return 1; + } +} + +/*! + * \brief Return template string, input operands string and output operands string. + * \param m The M in mMnNkK of MMA instructions. + * \param n The N in mMnNkK of MMA instructions. + * \param k The K in mMnNkK of MMA instructions. + * \param dtype_a The data type of multiplicand a. + * \param dtype_b The data type of multiplicand b. + * \param dtype_c The data type of accumulator c. + * \param sparse Whether it's Sparse MMA or not. + */ +inline std::tuple GetMMAOperands(int m, int n, int k, + ptx::DataType dtype_a, + ptx::DataType dtype_b, + ptx::DataType dtype_c, + bool sparse) { + std::stringstream templates, inputs, outputs; + const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a), + frag_attr_b = ptx::GetFragAttrs(dtype_b), + frag_attr_c = ptx::GetFragAttrs(dtype_c); + constexpr uint32_t warp_size = 32; + const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_a); + const int num_operands_a = + (m * k) * ptx::DTypeBits(dtype_a) / frag_attr_a.size / threads / (sparse ? 2 : 1), + num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads, + num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads; + + // generate templates; + int arg_counter = 0; + templates << "{" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_a; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_b; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}"; + // templates of metadata and sparse selector for sparse mma. + if (sparse) { + templates << ", %" << (arg_counter++) << ", F"; + } + + // generate inputs + for (int i = 0; i < num_operands_a; ++i) { + if (i != 0) { + inputs << ", "; + } + inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type << "(A))[" << i + << "])"; + } + for (int i = 0; i < num_operands_b; ++i) { + inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type << "(B))[" << i + << "])"; + } + for (int i = 0; i < num_operands_c; ++i) { + inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(C))[" << i + << "])"; + } + // input of metadata for sparse mma. + if (sparse) { + inputs << ", \"r\"(((unsigned *)(E))[0])"; + } + + // generate outputs + for (int i = 0; i < num_operands_c; ++i) { + if (i != 0) { + outputs << ","; + } + outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(D))[" << i + << "])"; + } + return std::make_tuple(templates.str(), inputs.str(), outputs.str()); +} + +std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, + const std::string& B_layout, const std::string& A_dtype, + const std::string& B_dtype, const std::string& C_dtype, + const std::string& a_ptr, const std::string& a_elem_offset, + const std::string& b_ptr, const std::string& b_elem_offset, + const std::string& c_ptr, const std::string& c_elem_offset, + const std::string& metadata, const std::string& metadata_offset, + const std::string& sparsity_selector, const std::string& bit_op, + bool sparse, bool saturate) { + ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b = ptx::DTypeFromString(B_dtype), + dtype_c = ptx::DTypeFromString(C_dtype); + if (dtype_a == ptx::DataType::kFloat32) { + dtype_a = ptx::DataType::kTensorFloat32; + } + if (dtype_b == ptx::DataType::kFloat32) { + dtype_b = ptx::DataType::kTensorFloat32; + } + ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout), + layout_b = ptx::LayoutTypeFromString(B_layout); + auto [m, n, k] = ptx::ParseMMAShape(shape); + CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, bit_op, sparse, + saturate); + std::string asm_code = R"( + { + __asm__ __volatile__( + "mma{.sparse}.sync.aligned{.shape}{.alayout}{.blayout}{.saturate}{.dtype}{.atype}{.btype}{.ctype}{.bitop}" + "{templates};\n" + : {outputs} + : {inputs}); + } +)"; + auto [templates_str, inputs_str, outputs_str] = + GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse); + + // replace patterns + Replacer replacer; + replacer.register_rule("{.sparse}", sparse ? ".sp" : ""); + replacer.register_rule("{.shape}", "." + shape); + replacer.register_rule("{.saturate}", saturate ? ".satfinite" : ""); + replacer.register_rule("{.alayout}", "." + A_layout); + replacer.register_rule("{.blayout}", "." + B_layout); + replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a)); + replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b)); + replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{.bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc"); + replacer.register_rule("{templates}", templates_str); + replacer.register_rule("{outputs}", outputs_str); + replacer.register_rule("{inputs}", inputs_str); + asm_code = replacer.rewrite(asm_code); + replacer.empty_rules(); + replacer.register_rule("A", a_ptr + " + " + a_elem_offset); + replacer.register_rule("B", b_ptr + " + " + b_elem_offset); + replacer.register_rule("C", c_ptr + " + " + c_elem_offset); + replacer.register_rule("D", c_ptr + " + " + c_elem_offset); + replacer.register_rule("E", metadata + " + " + metadata_offset); + replacer.register_rule("F", sparsity_selector); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + +inline std::tuple GetLoadMatrixOperands( + int num, const std::string& local_ptr, const std::string& local_elem_offset) { + std::stringstream templates, outputs; + int arg_counter = 0; + // generate templates + templates << "{%" << arg_counter++; + for (int i = 1; i < num; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, [%" << arg_counter++ << "]"; + // generate outputs + std::string ptr_type = "(unsigned *)"; + for (int i = 0; i < num; ++i) { + if (i != 0) { + outputs << ", "; + } + outputs << "\"=r\"((" << ptr_type << "(" << local_ptr << " + " << local_elem_offset << "))[" + << i << "])"; + } + return std::make_tuple(templates.str(), outputs.str()); +} + +std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type, + const std::string& local_ptr, + const std::string& local_elem_offset, + const std::string& smem_ptr, + const std::string& smem_elem_offset) { + CHECK(num == 1 || num == 2 || num == 4) << "ldmatrix only accept loading 1/2/4 matrices."; + ptx::DataType data_type = ptx::DTypeFromString(type); + CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16."; + std::string asm_code = R"( + { + unsigned int addr = cast_smem_ptr_to_int({smem_addr}); + __asm__ __volatile__( + "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}" + "{templates};\n" + : {outputs} + : "r"(addr) + ); + } +)"; + auto [templates_str, outputs_str] = GetLoadMatrixOperands(num, local_ptr, local_elem_offset); + + Replacer replacer; + replacer.register_rule("{.shape}", ".m8n8"); + replacer.register_rule("{.num}", ".x" + std::to_string(num)); + replacer.register_rule("{.trans}", trans ? ".trans" : ""); + replacer.register_rule("{.ss}", ".shared"); + replacer.register_rule("{.type}", ptx::DTypeToString(data_type)); + replacer.register_rule("{smem_addr}", smem_ptr + " + " + smem_elem_offset); + replacer.register_rule("{templates}", templates_str); + replacer.register_rule("{outputs}", outputs_str); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + +std::string PrintCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, const std::string& bytes) { + std::string asm_code = R"( + { + unsigned int addr = cast_smem_ptr_to_int({smem_addr}); + __asm__ __volatile__( + #if TVM_ENABLE_L2_PREFETCH + "cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}) + ); + } +)"; + Replacer replacer; + replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); + replacer.register_rule("{bytes}", bytes); + replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca"); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + +std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, + const std::string& bytes, + const std::string& predicate_value) { + CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || bytes == "2" || + bytes == "1") + << "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async"; + std::string predicated_asm_code = R"( + { + unsigned int addr = cast_smem_ptr_to_int({smem_addr}); + int pred_guard = (int){pred_guard}; + __asm__ __volatile__( + "{ .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + #if TVM_ENABLE_L2_PREFETCH + " @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;" + #else + " @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;" + #endif + " @!p {store_shared};}" + :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), {nopreg} + ); + } +)"; + auto [store_shared, nopreg] = [](const std::string& bytes) { + if (bytes == "16") + return std::make_tuple("st.shared.v4.u32 [%1], {%4, %5, %6, %7}", + "\"r\"(0), \"r\"(0), \"r\"(0),\"r\"(0)"); + else if (bytes == "12") + return std::make_tuple("st.shared.v3.u32 [%1], {%4, %5, %6}", "\"r\"(0), \"r\"(0), \"r\"(0)"); + else if (bytes == "8") + return std::make_tuple("st.shared.v2.u32 [%1], {%4, %5}", "\"r\"(0), \"r\"(0)"); + else if (bytes == "4") + return std::make_tuple("st.shared.u32 [%1], {%4}", "\"r\"(0)"); + else if (bytes == "2") + return std::make_tuple("st.shared.u16 [%1], {%4}", "\"r\"(0)"); + else if (bytes == "1") + return std::make_tuple("st.shared.u8 [%1], {%4}", "\"r\"(0)"); + else + return std::make_tuple("", ""); + }(bytes); + + Replacer replacer; + replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); + replacer.register_rule("{bytes}", bytes); + replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca"); + replacer.register_rule("{store_shared}", store_shared); + replacer.register_rule("{nopreg}", nopreg); + replacer.register_rule("{pred_guard}", predicate_value); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, const std::string& bytes, + const std::string& barrier) { + std::string asm_code = R"( + { + unsigned int smem_addr_int = cast_smem_ptr_to_int({smem_addr}); + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + __asm__ __volatile__( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" + :: "r"(smem_addr_int), "l"({global_ptr}), "r"({bytes}), "r"(barrier_addr_int) + : "memory" + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); + replacer.register_rule("{bytes}", bytes); + replacer.register_rule("{barrier}", "&" + barrier); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + +std::string PrintCpAsyncBarrierAsm(const std::string& barrier) { + std::string predicated_asm_code = R"( + { + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + __asm__ __volatile__( + "cp.async.mbarrier.arrive.shared.b64 [%0];" + :: "r" (barrier_addr_int) + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{barrier}", "&" + barrier); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string PrintInitBarrierThreadCountAsm(const std::string& barrier, + const std::string& thread_count) { + std::string predicated_asm_code = R"( + { + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + int thread_count = {thread_count}; + __asm__ __volatile__( + "mbarrier.init.shared.b64 [%0], %1;" + :: "r"(barrier_addr_int), "r"(thread_count) + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{barrier}", "&" + barrier); + replacer.register_rule("{thread_count}", thread_count); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string PrintArriveBarrierAsm(const std::string& barrier) { + std::string predicated_asm_code = R"( + { + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + __asm__ __volatile__( + "{ .reg .b64 state; mbarrier.arrive.shared.b64 state, [%0]; }" + :: "r"(barrier_addr_int) + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{barrier}", "&" + barrier); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier, + const std::string& byte_count) { + std::string predicated_asm_code = R"( + { + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + int byte_count = {byte_count}; + __asm__ __volatile__( + "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + :: "r"(barrier_addr_int), "r"(byte_count) + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{barrier}", "&" + barrier); + replacer.register_rule("{byte_count}", byte_count); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string PrintWaitBarrierAsm(const std::string& barrier) { + std::string predicated_asm_code = R"( + { + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + constexpr int phase_bit = 0; + __asm__ __volatile__( + "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; @P bra.uni DONE; bra.uni WAIT; DONE: }" + :: "r"(barrier_addr_int), "r"(phase_bit) + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{barrier}", "&" + barrier); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +} // namespace codegen +} // namespace tvm::tl diff --git a/src/target/ptx.h b/src/target/ptx.h new file mode 100644 index 000000000..72691fd44 --- /dev/null +++ b/src/target/ptx.h @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ptx.h + * \brief Code generation with inlined PTX code. + */ +#ifndef TVM_TL_TARGET_SOURCE_PTX_H_ +#define TVM_TL_TARGET_SOURCE_PTX_H_ + +#include + +#include +#include + +namespace tvm::tl { +namespace codegen { + +/*! + * \brief Print MMA assembly string given parameters. + * \param shape The shape string mMnNkK + * \param A_layout The layout of multiplicand A, can be either "row" or "col". + * \param B_layout The layout of multiplicand B, can be either "row" or "col". + * \param A_dtype The data type of multiplicand A. + * \param B_dtype The data type of multiplicand B. + * \param C_dtype The data type of multiplicand C. + * \param a_ptr Pointer to buffer A. + * \param a_offset The offset of element in A. + * \param b_ptr Pointer to buffer B. + * \param b_offset The offset of element in B. + * \param c_ptr Pointer to buffer C. + * \param c_offset The offset of element in C. + * \param metadata Pointer to metadata buffer (only used for sparse mma). + * \param metadata_offset The offset of element in metadata. + * \param sparsity_selector The sparsity selector in sparse mma. + * \param bit_op The bit operator used in 1-bit mma, can be either "xor" or "and". + * \param sparse Whether it's sparse mma or not. + * \param saturate Whether saturate output or not. + */ +std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, + const std::string& B_layout, const std::string& A_dtype, + const std::string& B_dtype, const std::string& C_dtype, + const std::string& a_ptr, const std::string& a_offset, + const std::string& b_ptr, const std::string& b_offset, + const std::string& c_ptr, const std::string& c_offset, + const std::string& metadata, const std::string& metadata_offset, + const std::string& sparsity_selector, const std::string& bit_op, + bool sparse, bool saturate); + +/*! + * \brief Print ldmatrix assembly string given parameters. + * \param trans: whether the matrix is loaded in column major format or not. + * \param num: number of matrices to load. + * \param type: The data type in the matrix, .b16 is the only accepted data type. + * \param local_ptr: pointer to local buffer. + * \param local_elem_offset: The offset of the element to store in the local buffer. + * \param smem_ptr: pointer to the shared memory buffer to load. + * \param smem_elem_offset: The offset of the start element of the row to load in shared memory. + */ +std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type, + const std::string& local_ptr, + const std::string& local_elem_offset, + const std::string& smem_ptr, + const std::string& smem_elem_offset); + +/*! + * \brief Print ptx cp.async assembly string given parameters. + * \param shared_ptr: The pointer to the destination shared memory. + * \param shared_elem_offset: The offset into the shared memory. + * \param global_ptr: The pointer to the global memory. + * \param global_elem_offset: The offset into the global memory. + * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. + */ +std::string PrintCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, const std::string& bytes); + +/*! + * \brief Print predicated ptx cp.async assembly string given parameters. + * \param shared_ptr: The pointer to the destination shared memory. + * \param shared_elem_offset: The offset into the shared memory. + * \param global_ptr: The pointer to the global memory. + * \param global_elem_offset: The offset into the global memory. + * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. + * \param predicate_value: The value of predicate `@p`. + */ +std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, + const std::string& bytes, + const std::string& predicate_value); + +/*! + * \brief Print ptx async copy from global to shared memory using cp.async.bulk + * \param shared_ptr: The pointer to the destination shared memory. + * \param shared_elem_offset: The offset into the shared memory. + * \param global_ptr: The pointer to the global memory. + * \param global_elem_offset: The offset into the global memory. + * \param bytes: The number of bytes to copy. + * \param barrier: The name of the barrier in shared memory. + */ +std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, const std::string& bytes, + const std::string& barrier); + +/*! + * \brief Print ptx async copy barrier using cp.async.mbarrier.arrive + * \param barrier: The name of the barrier in shared memory. + */ +std::string PrintCpAsyncBarrierAsm(const std::string& barrier); + +/*! + * \brief Print ptx barrier initialization of thread count using mbarrier.init + * \param barrier: The name of the barrier in shared memory. + * \param thread_count: The number of threads expected to arrive at the barrier. + */ +std::string PrintInitBarrierThreadCountAsm(const std::string& barrier, + const std::string& thread_count); + +/*! + * \brief Print ptx barrier arrival using mbarrier.arrive + * \param barrier: The name of the barrier in shared memory. + */ +std::string PrintArriveBarrierAsm(const std::string& barrier); + +/*! + * \brief Print ptx barrier arrival with expect tx operation using mbarrier.arrive.expect_tx + * \param barrier: The name of the barrier in shared memory. + * \param byte_count: Increases the tx count of the mbarrier object to track completion of + * addtional async transactions. + */ +std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier, + const std::string& byte_count); + +/*! + * \brief Print ptx barrier wait using mbarrier.try_wait + * \param barrier: The name of the barrier in shared memory. + */ +std::string PrintWaitBarrierAsm(const std::string& barrier); + +} // namespace codegen +} // namespace tvm::tl + +#endif // TVM_TL_TARGET_SOURCE_PTX_H_ diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index 332008528..881c975e0 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -126,6 +126,11 @@ def test_gemm_ss(): # float8 tests run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) + # tfloat32 test + run_gemm_ss(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_ss(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_ss(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_ss(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) def matmul_rs( M, @@ -254,6 +259,11 @@ def test_gemm_rs(): # float8 tests run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) + # float32 tests + run_gemm_rs(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_rs(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) def matmul_sr( M, @@ -382,6 +392,12 @@ def test_gemm_sr(): # float8 tests run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) + # float32 tests + run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_sr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) + def matmul_rr( M, @@ -515,7 +531,28 @@ def test_gemm_rr(): # float8 tests run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) + # float32 tests + run_gemm_rr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_rr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) + run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) if __name__ == "__main__": # tilelang.testing.main() - run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) + # run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float32", 128, 128, 32, 0) + # tilelang.disable_cache() + + run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 0) + run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 0) + run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 0) + run_gemm_sr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 0) + + run_gemm_rs(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 0) + run_gemm_rs(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 0) + run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 0) + run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 0) + + run_gemm_rr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 0) + run_gemm_rr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 0) + run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 0) + run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 0) diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index b1f5155ed..f1b3d3d82 100644 --- a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -3,30 +3,30 @@ import tilelang.language as T -def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): +def ldmatrix_32x4_to_shared_16x8_layout_a(thread_id, local_id): row = thread_id % 16 - col = 8 * (thread_id // 16) + local_id % 8 + col = (thread_id // 16) * 4 + local_id % 4 return row, col - -def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id): - row = 8 * (thread_id // 16) + (thread_id % 8) - col = 8 * ((thread_id % 16) // 8) + local_id % 8 +def ldmatrix_32x4_to_shared_16x8_layout_b(thread_id, local_id): + row = (thread_id // 16) * 8 + (thread_id % 8) + col = ((thread_id % 16) // 8) * 4 + local_id % 4 return row, col -def ldmatrix_16x32_to_shared_16x32_layout_a(thread_id, local_id): +def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): row = thread_id % 16 - col = 16 * (thread_id // 16) + local_id % 16 + col = 8 * (thread_id // 16) + local_id % 8 return row, col -def ldmatrix_16x32_to_shared_16x32_layout_b(thread_id, local_id): +def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id): row = 8 * (thread_id // 16) + (thread_id % 8) - col = 16 * ((thread_id % 16) // 8) + local_id % 16 + col = 8 * ((thread_id % 16) // 8) + local_id % 8 return row, col + def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id): row = thread_id % 16 col = local_id + (thread_id // 16) * 16 @@ -48,6 +48,26 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): # sr represents spatial + reduction layout # the first axis is spatial while the second axis is reduction # mma.sync matrix A layout, if wanna trans, please apply map_indices +def shared_16x8_to_mma_a_32x4_layout(i, j): + thread_id = 4 * (i % 8) + (j % 4) + return thread_id, 2 * (j // 4) + (i // 8) + +def shared_16x8_to_mma_a_32x4_layout_trans(i, j): + return shared_16x8_to_mma_a_32x4_layout(j, i) + +# mma.sync matrix B layout, if wanna trans, please apply map_indices +def shared_16x8_to_mma_b_32x4_layout(i, j): + thread_id = 4 * (i % 8) + (j % 4) + return thread_id, 2 * (i // 8) + (j // 4) + +def shared_16x8_to_mma_b_32x4_layout_trans(i, j): + return shared_16x8_to_mma_b_32x4_layout(j, i) + +shared_16x8_to_mma_32x4_layout_sr_a = shared_16x8_to_mma_a_32x4_layout +shared_16x8_to_mma_32x4_layout_sr_b = shared_16x8_to_mma_b_32x4_layout +shared_16x8_to_mma_32x4_layout_rs_a = shared_16x8_to_mma_a_32x4_layout_trans +shared_16x8_to_mma_32x4_layout_rs_b = shared_16x8_to_mma_b_32x4_layout_trans + def shared_16x16_to_mma_a_32x8_layout(i, j): thread_id = 4 * (i % 8) + (j % 8) // 2 return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2) @@ -57,7 +77,6 @@ def shared_16x16_to_mma_a_32x8_layout_trans(i, j): return shared_16x16_to_mma_a_32x8_layout(j, i) -# mma.sync matrix B layout, if wanna trans, please apply map_indices def shared_16x16_to_mma_b_32x8_layout(i, j): thread_id = 4 * (i % 8) + (j % 8) // 2 return thread_id, 4 * (i // 8) + (j // 8) * 2 + (j % 2) @@ -103,6 +122,18 @@ def mma_32x8_to_shared_16x16_layout(thread_id, local_id): return row, col +def mma_load_a_32x4_to_shared_16x8_layout(thread_id, local_id): + row = 8 * (local_id % 2) + (thread_id // 4) + col = 4 * (local_id // 2) + (thread_id % 4) + return row, col + + +def mma_load_b_32x4_to_shared_16x8_layout(thread_id, local_id): + row = 8 * (local_id // 2) + (thread_id // 4) + col = 4 * (local_id % 2) + (thread_id % 4) + return row, col + + def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id): row = 8 * (local_id % 8 // 4) + (thread_id // 4) col = 16 * (local_id // 8) + (thread_id % 4) * 4 + (local_id % 4) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index c5a09a906..c4cd0470d 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -10,10 +10,14 @@ ) from tilelang.utils import is_fragment from tilelang.intrinsics.mma_layout import ( + shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x8_to_mma_32x4_layout_sr_b, shared_16x16_to_mma_32x8_layout_sr_a, shared_16x16_to_mma_32x8_layout_sr_b, shared_16x32_to_mma_32x16_layout_sr_a, shared_16x32_to_mma_32x16_layout_sr_b, + mma_load_a_32x4_to_shared_16x8_layout, + mma_load_b_32x4_to_shared_16x8_layout, mma_load_a_32x16_to_shared_16x32_layout, mma_load_b_32x16_to_shared_16x32_layout, ) @@ -105,9 +109,14 @@ def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] def _initialize_mma_prefix(self, k_dim: int = 16): - if k_dim == 16: + if k_dim == 8: + # typically used for tfloat32 + self.mma_prefix = "m16n8k8" + elif k_dim == 16: + # typically used for float16/bfloat16 self.mma_prefix = "m16n8k16" elif k_dim == 32: + # typically used for int8/fp8 self.mma_prefix = "m16n8k32" else: raise ValueError("Unsupported k_dim") @@ -201,7 +210,15 @@ def ldmatrix_a(self, a_dtype = self.a_dtype a_transposed = self.a_transposed # ldmatrix cannot be used for int8 + trans case. - ldmatrix_available = not (DataType(a_dtype).bits == 8 and a_transposed) + ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed) + mma_load_layout = lambda i, j: (i, j) + if not ldmatrix_available: + if DataType(a_dtype).bits == 8: + mma_load_layout = mma_load_a_32x16_to_shared_16x32_layout + elif DataType(a_dtype).bits == 32: + mma_load_layout = mma_load_a_32x4_to_shared_16x8_layout + else: + raise ValueError(f"Unsupported dtype: {a_dtype}") thread_binding = self.get_thread_binding() @@ -235,7 +252,7 @@ def _warp_ldmatrix_a( ) else: for j in T.serial(local_size_a): - mi, mk = mma_load_a_32x16_to_shared_16x32_layout(tx, j) + mi, mk = mma_load_layout(tx, j) A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi] return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) @@ -256,7 +273,15 @@ def ldmatrix_b(self, thread_binding = self.get_thread_binding() replicate_b = (self.n_dim == 16) # ldmatrix cannot be used for int8 + trans case. - ldmatrix_available = not (DataType(b_dtype).bits == 8 and not b_transposed) + ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) + mma_load_layout = lambda i, j: (i, j) + if not ldmatrix_available: + if DataType(b_dtype).bits == 8: + mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout + elif DataType(b_dtype).bits == 32: + mma_load_layout = mma_load_b_32x4_to_shared_16x8_layout + else: + raise ValueError(f"Unsupported dtype: {b_dtype}") @T.macro def _warp_ldmatrix_b( @@ -296,7 +321,7 @@ def _warp_ldmatrix_b( # load 16x32 data from shared buffer to local buffer # must be transposed. for j in T.serial(local_size_b): - mi, mk = mma_load_b_32x16_to_shared_16x32_layout(tx, j) + mi, mk = mma_load_layout(tx, j) B_local_buf[i * local_size_b + j] = B_shared_buf[wk + mk, wi + mi] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) @@ -458,7 +483,11 @@ def make_mma_load_layout(self, # then rs also can represent a transposed basic layout transform_func_sr_a: Callable = None transform_func_sr_b: Callable = None - if dtype_bits == 16: + if dtype_bits == 32: + ... + transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a + transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b + elif dtype_bits == 16: transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b elif dtype_bits == 8: diff --git a/tilelang/intrinsics/utils.py b/tilelang/intrinsics/utils.py index 08730a40a..a48801b1d 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/intrinsics/utils.py @@ -1,9 +1,11 @@ from tvm import DataType from typing import Literal from .mma_layout import ( + ldmatrix_32x4_to_shared_16x8_layout_a, + ldmatrix_32x4_to_shared_16x8_layout_b, ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, - ldmatrix_16x32_to_shared_16x32_layout_a, + ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, mma_store_32x8_to_shared_16x16_layout, ) @@ -26,7 +28,18 @@ def get_ldmatrix_offset( ): assert matrix in ["A", "B"], "matrix should be either A or B" dtype_bits = DataType(dtype).bits - if dtype_bits == 16: + if dtype_bits == 32: + if matrix == "B" and transposed: + transform_func = ldmatrix_32x4_to_shared_16x8_layout_b + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + elif matrix == "A" and not transposed: + transform_func = ldmatrix_32x4_to_shared_16x8_layout_a + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8") + elif dtype_bits == 16: transform_func = ldmatrix_32x8_to_shared_16x16_layout transform_func_trans = ldmatrix_trans_32x8_to_shared_16x16_layout if transposed: @@ -41,7 +54,7 @@ def get_ldmatrix_offset( new_row_idx, new_col_idx = transform_func(row_idx, col_idx) return new_row_idx * stride + new_col_idx elif matrix == "A" and not transposed: - transform_func = ldmatrix_16x32_to_shared_16x32_layout_a + transform_func = ldmatrix_32x16_to_shared_16x32_layout_a new_row_idx, new_col_idx = transform_func(row_idx, col_idx) return new_row_idx * stride + new_col_idx else: diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py index e9e9490d9..724187205 100644 --- a/tilelang/tileop/gemm/gemm_base.py +++ b/tilelang/tileop/gemm/gemm_base.py @@ -2,65 +2,53 @@ from tilelang import tvm as tvm from tvm.target import Target from tvm import tir -from tvm.tir import Var from tilelang.utils.language import is_shared, is_fragment from tilelang.ir import GemmWarpPolicy +from tvm.ir.base import Node + @dataclass class GemmBase(object): - gemm_node: "GemmPy" + gemm_node: Node def infer_layout(self, target: Target, thread_nums: int): raise NotImplementedError("infer_layout is not implemented") - + def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): raise NotImplementedError("lower is not implemented") - def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) def is_gemm_sr(self) -> bool: return is_shared(self.A) and is_fragment(self.B) - + def is_gemm_rs(self) -> bool: return is_fragment(self.A) and is_shared(self.B) - + def is_gemm_rr(self) -> bool: return is_fragment(self.A) and is_fragment(self.B) @property def M(self) -> int: return self.gemm_node.M - + @property def N(self) -> int: return self.gemm_node.N - + @property def K(self) -> int: return self.gemm_node.K - + @property def trans_A(self) -> bool: return self.gemm_node.trans_A - + @property def trans_B(self) -> bool: return self.gemm_node.trans_B - - @property - def in_dtype(self) -> str: - return self.gemm_node.in_dtype - - @property - def accum_dtype(self) -> str: - return self.gemm_node.accum_dtype - - @property - def chunk(self) -> int: - return self.gemm_node.chunk - + @property def in_dtype(self) -> str: assert self.A.dtype == self.B.dtype, "A and B must have the same dtype" @@ -73,31 +61,31 @@ def accum_dtype(self) -> str: @property def chunk(self) -> int: return self.A.shape[-2] if self.trans_A else self.A.shape[-1] - + @property def A(self) -> tir.Buffer: return self.gemm_node.A - + @property def B(self) -> tir.Buffer: return self.gemm_node.B - + @property def C(self) -> tir.Buffer: return self.gemm_node.C - + @property def APtr(self) -> tir.PrimExpr: return self.gemm_node.APtr - + @property def BPtr(self) -> tir.PrimExpr: return self.gemm_node.BPtr - + @property def CPtr(self) -> tir.PrimExpr: return self.gemm_node.CPtr - + @property def stride_A(self) -> int: return self.gemm_node.stride_A @@ -105,27 +93,27 @@ def stride_A(self) -> int: @property def stride_B(self) -> int: return self.gemm_node.stride_B - + @property def offset_A(self) -> int: return self.gemm_node.offset_A - + @property def offset_B(self) -> int: return self.gemm_node.offset_B - + @property def clear_accum(self) -> bool: return self.gemm_node.clear_accum - + @property def k_pack(self) -> int: return self.gemm_node.k_pack - + @property def wg_wait(self) -> int: return self.gemm_node.wg_wait - + @property def policy(self) -> GemmWarpPolicy: return self.gemm_node.policy diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/tileop/gemm/gemm_mma.py index fa82a2034..a046ee126 100644 --- a/tilelang/tileop/gemm/gemm_mma.py +++ b/tilelang/tileop/gemm/gemm_mma.py @@ -45,7 +45,7 @@ def infer_layout(self, target: Target, thread_nums: int): return { self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), self.B: make_swizzled_layout(self.B), - self.C: mma_emitter.make_mma_store_layout(self.C), + self.C: mma_emitter.make_mma_store_layout(self.C), } elif self.is_gemm_rr(): return { @@ -54,7 +54,8 @@ def infer_layout(self, target: Target, thread_nums: int): self.C: mma_emitter.make_mma_store_layout(self.C), } else: - raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError( + f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, @@ -87,6 +88,7 @@ def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): C_local = self.C if self.is_gemm_ss(): + @T.prim_func def _gemm_ssr() -> None: """ @@ -120,6 +122,7 @@ def _gemm_ssr() -> None: return _Simplify(_gemm_ssr, inline_let=True) elif self.is_gemm_sr(): B_local = self.B + @T.prim_func def _gemm_srr() -> None: """ @@ -130,7 +133,7 @@ def _gemm_srr() -> None: A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) for ki in T.serial(0, (block_K // micro_size_k)): - + # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -144,10 +147,11 @@ def _gemm_srr() -> None: # Simplify to optimize the index computing # Must inline let statements to simplify the analysis # alloc_buffers body - # insert into parrent block + # insert into parent block return _Simplify(_gemm_srr, inline_let=True) elif self.is_gemm_rs(): A_local = self.A + @T.prim_func def _gemm_rsr() -> None: """ @@ -158,7 +162,7 @@ def _gemm_rsr() -> None: B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) for ki in T.serial(0, (block_K // micro_size_k)): - + # Load B into fragment mma_emitter.ldmatrix_b( B_local, @@ -175,6 +179,7 @@ def _gemm_rsr() -> None: elif self.is_gemm_rr(): A_local = self.A B_local = self.B + @T.prim_func def _gemm_rsr() -> None: """ @@ -191,17 +196,17 @@ def _gemm_rsr() -> None: # Must inline let statements to simplify the analysis return _Simplify(_gemm_rsr, inline_let=True) else: - raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") - + raise ValueError( + f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) def is_gemm_sr(self) -> bool: return is_shared(self.A) and is_fragment(self.B) - + def is_gemm_rs(self) -> bool: return is_fragment(self.A) and is_shared(self.B) - + def is_gemm_rr(self) -> bool: return is_fragment(self.A) and is_fragment(self.B) From 3da08a1c9873e650f543b90744d0a26963051fa8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 11 Sep 2025 01:20:11 +0800 Subject: [PATCH 08/10] lint fix --- .clang-tidy | 1 + src/target/codegen_cuda.cc | 2 +- src/target/ptx.cc | 440 ++++++++++-------- src/target/ptx.h | 105 +++-- .../test_tilelang_tilelibrary_gemm.py | 22 +- tilelang/intrinsics/mma_layout.py | 7 +- tilelang/intrinsics/mma_macro_generator.py | 10 +- 7 files changed, 314 insertions(+), 273 deletions(-) diff --git a/.clang-tidy b/.clang-tidy index 7d796085d..8631d9211 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -41,6 +41,7 @@ Checks: > -clang-analyzer-optin.cplusplus.UninitializedObject, -cppcoreguidelines-pro-type-static-cast-downcast, -performance-unnecessary-value-param, + -performance-enum-size, WarningsAsErrors: '*' diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index a2f58b67b..21dc509cf 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -14,8 +14,8 @@ #include #include "../op/builtin.h" -#include "arith/pattern_match.h" #include "./ptx.h" +#include "arith/pattern_match.h" namespace tvm { namespace codegen { diff --git a/src/target/ptx.cc b/src/target/ptx.cc index f872cad0b..14d1b0460 100644 --- a/src/target/ptx.cc +++ b/src/target/ptx.cc @@ -69,12 +69,13 @@ enum class DataType : int { kBit64 = 22 }; -static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", - ".s32", ".u32", ".s64", ".u64", ".e4m3", ".e5m2", - ".f16", ".bf16", ".f16x2", ".f32", ".tf32", ".f64", - ".b1", ".b8", ".b16", ".b32", ".b64"}; -static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 8, 8, - 16, 16, 32, 32, 32, 64, 1, 8, 16, 32, 64}; +static const char *dtype_str[] = { + ".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32", ".u32", + ".s64", ".u64", ".e4m3", ".e5m2", ".f16", ".bf16", ".f16x2", ".f32", + ".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"}; +static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, + 64, 64, 8, 8, 16, 16, 32, 32, + 32, 64, 1, 8, 16, 32, 64}; /*! * \brief Create PTX data type from string. @@ -134,22 +135,27 @@ inline DataType DTypeFromString(const std::string str) { /*! * \brief Get the string representation of given PTX data type. */ -inline std::string DTypeToString(DataType dtype) { return dtype_str[static_cast(dtype)]; } +inline std::string DTypeToString(DataType dtype) { + return dtype_str[static_cast(dtype)]; +} /*! * \brief Get the number of bits of given PTX data type. */ -inline uint32_t DTypeBits(DataType dtype) { return num_bits[static_cast(dtype)]; } +inline uint32_t DTypeBits(DataType dtype) { + return num_bits[static_cast(dtype)]; +} /*! * \brief Extract the value m, n, k from string m*n*k* */ -inline std::tuple ParseMMAShape(const std::string& str) { - size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k"); +inline std::tuple ParseMMAShape(const std::string &str) { + size_t pos_m = str.find('m'), pos_n = str.find('n'), pos_k = str.find('k'); CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) << "Cannot parse MMA shape " << str; int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)), - n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1)); + n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), + k = std::stoi(str.substr(pos_k + 1)); return std::make_tuple(m, n, k); } @@ -161,7 +167,7 @@ enum class LayoutType : int { kRowMajor = 0, kColumnMajor = 1 }; /*! * \brief Parse layout type */ -LayoutType LayoutTypeFromString(const std::string& str) { +LayoutType LayoutTypeFromString(const std::string &str) { if (str == "row") { return LayoutType::kRowMajor; } else if (str == "col") { @@ -171,7 +177,7 @@ LayoutType LayoutTypeFromString(const std::string& str) { } } -static const char* layout_type_str[] = {"row", "col"}; +static const char *layout_type_str[] = {"row", "col"}; /*! * \brief Convert layout type to string. @@ -184,15 +190,18 @@ inline std::string LayoutTypeToString(LayoutType layout) { * \brief MMA Configurations, used to determine validity. */ struct MMAConfig { - explicit MMAConfig(int m, int n, int k, DataType dtype_mul, bool use_bit_op, bool sparse) - : m(m), n(n), k(k), dtype_mul(dtype_mul), use_bit_op(use_bit_op), sparse(sparse) {} + explicit MMAConfig(int m, int n, int k, DataType dtype_mul, bool use_bit_op, + bool sparse) + : m(m), n(n), k(k), dtype_mul(dtype_mul), use_bit_op(use_bit_op), + sparse(sparse) {} int m, n, k; DataType dtype_mul; bool use_bit_op; bool sparse; - inline bool operator==(const MMAConfig& other) { - return m == other.m && n == other.n && k == other.k && dtype_mul == other.dtype_mul && - use_bit_op == other.use_bit_op && sparse == other.sparse; + inline bool operator==(const MMAConfig &other) { + return m == other.m && n == other.n && k == other.k && + dtype_mul == other.dtype_mul && use_bit_op == other.use_bit_op && + sparse == other.sparse; } }; @@ -248,77 +257,86 @@ const MMAConfig valid_mma_configs[] = { }; /*! - * \brief Check whether the multiplicand data type and accumulator data type is valid for MMA - * computation. - * \param dtype_a The data type of multiplicand a. + * \brief Check whether the multiplicand data type and accumulator data type is + * valid for MMA computation. \param dtype_a The data type of multiplicand a. * \param dtype_b The data type of multiplicand b. * \param dtype_c The data type of accumulator c. * \note Reference: * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types */ -void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) { - std::string ab_not_match_err_str = "The multiplicands' data type " + DTypeToString(dtype_a) + +void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, + DataType dtype_c) { + std::string ab_not_match_err_str = "The multiplicands' data type " + + DTypeToString(dtype_a) + DTypeToString(dtype_b) + " do not match."; // check a and b switch (dtype_a) { - case DataType::kBit1: - case DataType::kFloat16: - case DataType::kBFloat16: - case DataType::kFloat32: - case DataType::kTensorFloat32: - case DataType::kFloat64: - CHECK(dtype_a == dtype_b) << ab_not_match_err_str; - break; - case DataType::kInt4: - case DataType::kUInt4: - CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) << ab_not_match_err_str; - break; - case DataType::kInt8: - case DataType::kUInt8: - CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << ab_not_match_err_str; - break; - case DataType::kFloat8_e4m3: - case DataType::kFloat8_e5m2: - CHECK(dtype_b == DataType::kFloat8_e4m3 || dtype_b == DataType::kFloat8_e5m2) - << ab_not_match_err_str; - break; - default: - CHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a) - << DTypeToString(dtype_b); + case DataType::kBit1: + case DataType::kFloat16: + case DataType::kBFloat16: + case DataType::kFloat32: + case DataType::kTensorFloat32: + case DataType::kFloat64: + CHECK(dtype_a == dtype_b) << ab_not_match_err_str; + break; + case DataType::kInt4: + case DataType::kUInt4: + CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) + << ab_not_match_err_str; + break; + case DataType::kInt8: + case DataType::kUInt8: + CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) + << ab_not_match_err_str; + break; + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: + CHECK(dtype_b == DataType::kFloat8_e4m3 || + dtype_b == DataType::kFloat8_e5m2) + << ab_not_match_err_str; + break; + default: + CHECK(false) << "Invalid multiplicand data types: " + << DTypeToString(dtype_a) << DTypeToString(dtype_b); } // check a,b and c switch (dtype_a) { - case DataType::kBit1: - case DataType::kInt4: - case DataType::kUInt4: - case DataType::kInt8: - case DataType::kUInt8: - CHECK(dtype_c == DataType::kInt32) - << "For multiplicand data type " << DTypeToString(dtype_a) << DTypeToString(dtype_b) - << ", accumulator data type should be s32."; - break; - case DataType::kFloat16: - CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32) - << "For multiplicand data type f16, accumulator data type should be f16/f32."; - break; - case DataType::kBFloat16: - case DataType::kFloat32: - case DataType::kTensorFloat32: - CHECK(dtype_c == DataType::kFloat32) - << "For multiplicand data type bf16/tf32, accumulator data type can only be f32."; - break; - case DataType::kFloat64: - CHECK(dtype_c == DataType::kFloat64) - << "For multiplicand data type f64, accumulator data type can only be f64."; - break; - case DataType::kFloat8_e4m3: - case DataType::kFloat8_e5m2: - CHECK(dtype_c == DataType::kFloat32) - << "For multiplicand data type e4m3/e5m2, accumulator data type can only be f32."; - break; - default: - CHECK(false) << "Invalid multiplicand/accumulator data types: " << DTypeToString(dtype_a) - << DTypeToString(dtype_b) << DTypeToString(dtype_c) << "."; + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + CHECK(dtype_c == DataType::kInt32) + << "For multiplicand data type " << DTypeToString(dtype_a) + << DTypeToString(dtype_b) << ", accumulator data type should be s32."; + break; + case DataType::kFloat16: + CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32) + << "For multiplicand data type f16, accumulator data type should be " + "f16/f32."; + break; + case DataType::kBFloat16: + case DataType::kFloat32: + case DataType::kTensorFloat32: + CHECK(dtype_c == DataType::kFloat32) + << "For multiplicand data type bf16/tf32, accumulator data type can " + "only be f32."; + break; + case DataType::kFloat64: + CHECK(dtype_c == DataType::kFloat64) + << "For multiplicand data type f64, accumulator data type can only be " + "f64."; + break; + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: + CHECK(dtype_c == DataType::kFloat32) + << "For multiplicand data type e4m3/e5m2, accumulator data type can " + "only be f32."; + break; + default: + CHECK(false) << "Invalid multiplicand/accumulator data types: " + << DTypeToString(dtype_a) << DTypeToString(dtype_b) + << DTypeToString(dtype_c) << "."; } } @@ -332,37 +350,41 @@ void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_ * \param dtype_a The data type of multiplicand A. * \param dtype_b The data type of multiplicand B. * \param dtype_c The data type of accumulator C. - * \param bit_op The bit operator for 1-bit MMA computation, can be "xor"/"and" or ""(if it's not - * 1-bit MMA). - * \param sparse Whether it's Sparse MMA or not. + * \param bit_op The bit operator for 1-bit MMA computation, can be "xor"/"and" + * or ""(if it's not 1-bit MMA). \param sparse Whether it's Sparse MMA or not. * \param saturate Whether saturate output or not. */ -void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType layout_b, - DataType dtype_a, DataType dtype_b, DataType dtype_c, - const std::string& bit_op, bool sparse, bool saturate) { - CHECK(bit_op == "xor" || bit_op == "and" || bit_op == "") +void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, + LayoutType layout_b, DataType dtype_a, + DataType dtype_b, DataType dtype_c, + const std::string &bit_op, bool sparse, + bool saturate) { + CHECK(bit_op == "xor" || bit_op == "and" || bit_op.empty()) << "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and."; bool use_bit_op = !bit_op.empty(); if (use_bit_op) { - CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1-bit multiplicand."; + CHECK(dtype_a == DataType::kBit1) + << "Bit operator is only compatible with 1-bit multiplicand."; } CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c); if (saturate) { - CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a == DataType::kInt8 || - dtype_a == DataType::kUInt8) - << "Output saturation only applicable to multiplicand type s4/u4/s8/u8."; + CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || + dtype_a == DataType::kInt8 || dtype_a == DataType::kUInt8) + << "Output saturation only applicable to multiplicand type " + "s4/u4/s8/u8."; } if (!(m == 8 && n == 8 && k == 4 && dtype_a == ptx::DataType::kFloat16)) { // Only MMA on m8n8k4 for fp16 supports customized layouts. - CHECK(layout_a == LayoutType::kRowMajor && layout_b == LayoutType::kColumnMajor) + CHECK(layout_a == LayoutType::kRowMajor && + layout_b == LayoutType::kColumnMajor) << "Invalid layout combination " << LayoutTypeToString(layout_a) << "," << LayoutTypeToString(layout_b) << "."; } MMAConfig config(m, n, k, dtype_a, use_bit_op, sparse); bool match = false; - for (const MMAConfig& valid_config : valid_mma_configs) { + for (const MMAConfig &valid_config : valid_mma_configs) { if (config == valid_config) { match = true; break; @@ -375,7 +397,7 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType * \brief Fragment attributes */ class FragAttrs { - public: +public: explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type) : reg_type(reg_type), size(size), ptr_type(ptr_type) {} /*! \brief PTX register type */ @@ -391,43 +413,44 @@ class FragAttrs { */ inline FragAttrs GetFragAttrs(DataType dtype) { switch (dtype) { - case DataType::kBit1: - case DataType::kInt4: - case DataType::kUInt4: - case DataType::kInt8: - case DataType::kUInt8: - case DataType::kFloat8_e4m3: - case DataType::kFloat8_e5m2: - case DataType::kBit16: - case DataType::kFloat16: // .f16x2 register - case DataType::kBFloat16: - case DataType::kTensorFloat32: - return FragAttrs('r', 32, "(unsigned *)"); - case DataType::kInt32: - return FragAttrs('r', 32, "(int *)"); - case DataType::kFloat32: - return FragAttrs('f', 32, "(float *)"); - case DataType::kFloat64: - return FragAttrs('d', 64, "(double *)"); - default: - ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in MMA."; - return FragAttrs('\0', 0, ""); + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: + case DataType::kBit16: + case DataType::kFloat16: // .f16x2 register + case DataType::kBFloat16: + case DataType::kTensorFloat32: + return FragAttrs('r', 32, "(unsigned *)"); + case DataType::kInt32: + return FragAttrs('r', 32, "(int *)"); + case DataType::kFloat32: + return FragAttrs('f', 32, "(float *)"); + case DataType::kFloat64: + return FragAttrs('d', 64, "(double *)"); + default: + ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in MMA."; + return FragAttrs('\0', 0, ""); } } -}; // namespace ptx +}; // namespace ptx /*! * \brief Replace patterns with replacement strings. * \note should use std::format instead when codebase is ported to C++20. */ class Replacer { - public: - void register_rule(const std::string& pattern, const std::string& replacement) { +public: + void register_rule(const std::string &pattern, + const std::string &replacement) { _rules.emplace_back(pattern, replacement); } std::string rewrite(std::string str) { - for (auto&& rule : _rules) { + for (auto &&rule : _rules) { auto [pattern, replacement] = rule; size_t len = pattern.size(); size_t new_len = replacement.size(); @@ -441,14 +464,15 @@ class Replacer { } void empty_rules() { _rules.clear(); } - private: +private: std::vector> _rules; }; /*! * \brief Get the number of MMA computations for given shape and datatype. */ -inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype) { +inline uint32_t GetNumMMAComputations(int m, int n, int k, + ptx::DataType dtype) { if (m == 8 && n == 8 && k == 4 && dtype == ptx::DataType::kFloat16) { // MMA for m8n8k4 on fp16 would launch 4 MMA computations instead of one. return 4; @@ -458,30 +482,29 @@ inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype) } /*! - * \brief Return template string, input operands string and output operands string. - * \param m The M in mMnNkK of MMA instructions. - * \param n The N in mMnNkK of MMA instructions. - * \param k The K in mMnNkK of MMA instructions. + * \brief Return template string, input operands string and output operands + * string. \param m The M in mMnNkK of MMA instructions. \param n The N in + * mMnNkK of MMA instructions. \param k The K in mMnNkK of MMA instructions. * \param dtype_a The data type of multiplicand a. * \param dtype_b The data type of multiplicand b. * \param dtype_c The data type of accumulator c. * \param sparse Whether it's Sparse MMA or not. */ -inline std::tuple GetMMAOperands(int m, int n, int k, - ptx::DataType dtype_a, - ptx::DataType dtype_b, - ptx::DataType dtype_c, - bool sparse) { +inline std::tuple +GetMMAOperands(int m, int n, int k, ptx::DataType dtype_a, + ptx::DataType dtype_b, ptx::DataType dtype_c, bool sparse) { std::stringstream templates, inputs, outputs; const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a), frag_attr_b = ptx::GetFragAttrs(dtype_b), frag_attr_c = ptx::GetFragAttrs(dtype_c); constexpr uint32_t warp_size = 32; const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_a); - const int num_operands_a = - (m * k) * ptx::DTypeBits(dtype_a) / frag_attr_a.size / threads / (sparse ? 2 : 1), - num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads, - num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads; + const int num_operands_a = (m * k) * ptx::DTypeBits(dtype_a) / + frag_attr_a.size / threads / (sparse ? 2 : 1), + num_operands_b = + (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads, + num_operands_c = + (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads; // generate templates; int arg_counter = 0; @@ -516,16 +539,16 @@ inline std::tuple GetMMAOperands(int m, i if (i != 0) { inputs << ", "; } - inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type << "(A))[" << i - << "])"; + inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type + << "(A))[" << i << "])"; } for (int i = 0; i < num_operands_b; ++i) { - inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type << "(B))[" << i - << "])"; + inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type + << "(B))[" << i << "])"; } for (int i = 0; i < num_operands_c; ++i) { - inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(C))[" << i - << "])"; + inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type + << "(C))[" << i << "])"; } // input of metadata for sparse mma. if (sparse) { @@ -537,22 +560,25 @@ inline std::tuple GetMMAOperands(int m, i if (i != 0) { outputs << ","; } - outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(D))[" << i - << "])"; + outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type + << "(D))[" << i << "])"; } return std::make_tuple(templates.str(), inputs.str(), outputs.str()); } -std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, - const std::string& B_layout, const std::string& A_dtype, - const std::string& B_dtype, const std::string& C_dtype, - const std::string& a_ptr, const std::string& a_elem_offset, - const std::string& b_ptr, const std::string& b_elem_offset, - const std::string& c_ptr, const std::string& c_elem_offset, - const std::string& metadata, const std::string& metadata_offset, - const std::string& sparsity_selector, const std::string& bit_op, - bool sparse, bool saturate) { - ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b = ptx::DTypeFromString(B_dtype), +std::string +PrintMMAAssembly(const std::string &shape, const std::string &A_layout, + const std::string &B_layout, const std::string &A_dtype, + const std::string &B_dtype, const std::string &C_dtype, + const std::string &a_ptr, const std::string &a_elem_offset, + const std::string &b_ptr, const std::string &b_elem_offset, + const std::string &c_ptr, const std::string &c_elem_offset, + const std::string &metadata, + const std::string &metadata_offset, + const std::string &sparsity_selector, + const std::string &bit_op, bool sparse, bool saturate) { + ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), + dtype_b = ptx::DTypeFromString(B_dtype), dtype_c = ptx::DTypeFromString(C_dtype); if (dtype_a == ptx::DataType::kFloat32) { dtype_a = ptx::DataType::kTensorFloat32; @@ -563,8 +589,8 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout), layout_b = ptx::LayoutTypeFromString(B_layout); auto [m, n, k] = ptx::ParseMMAShape(shape); - CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, bit_op, sparse, - saturate); + CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, + bit_op, sparse, saturate); std::string asm_code = R"( { __asm__ __volatile__( @@ -588,7 +614,8 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b)); replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c)); replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c)); - replacer.register_rule("{.bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc"); + replacer.register_rule("{.bitop}", + bit_op.empty() ? "" : "." + bit_op + ".popc"); replacer.register_rule("{templates}", templates_str); replacer.register_rule("{outputs}", outputs_str); replacer.register_rule("{inputs}", inputs_str); @@ -604,8 +631,9 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo return asm_code; } -inline std::tuple GetLoadMatrixOperands( - int num, const std::string& local_ptr, const std::string& local_elem_offset) { +inline std::tuple +GetLoadMatrixOperands(int num, const std::string &local_ptr, + const std::string &local_elem_offset) { std::stringstream templates, outputs; int arg_counter = 0; // generate templates @@ -620,20 +648,23 @@ inline std::tuple GetLoadMatrixOperands( if (i != 0) { outputs << ", "; } - outputs << "\"=r\"((" << ptr_type << "(" << local_ptr << " + " << local_elem_offset << "))[" - << i << "])"; + outputs << "\"=r\"((" << ptr_type << "(" << local_ptr << " + " + << local_elem_offset << "))[" << i << "])"; } return std::make_tuple(templates.str(), outputs.str()); } -std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type, - const std::string& local_ptr, - const std::string& local_elem_offset, - const std::string& smem_ptr, - const std::string& smem_elem_offset) { - CHECK(num == 1 || num == 2 || num == 4) << "ldmatrix only accept loading 1/2/4 matrices."; +std::string PrintLoadMatrixAssembly(bool trans, int num, + const std::string &type, + const std::string &local_ptr, + const std::string &local_elem_offset, + const std::string &smem_ptr, + const std::string &smem_elem_offset) { + CHECK(num == 1 || num == 2 || num == 4) + << "ldmatrix only accept loading 1/2/4 matrices."; ptx::DataType data_type = ptx::DTypeFromString(type); - CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16."; + CHECK(data_type == ptx::DataType::kBit16) + << "ldmatrix only accept matrix with type .b16."; std::string asm_code = R"( { unsigned int addr = cast_smem_ptr_to_int({smem_addr}); @@ -645,7 +676,8 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type ); } )"; - auto [templates_str, outputs_str] = GetLoadMatrixOperands(num, local_ptr, local_elem_offset); + auto [templates_str, outputs_str] = + GetLoadMatrixOperands(num, local_ptr, local_elem_offset); Replacer replacer; replacer.register_rule("{.shape}", ".m8n8"); @@ -660,10 +692,11 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type return asm_code; } -std::string PrintCpAsyncAssembly(const std::string& shared_ptr, - const std::string& shared_elem_offset, - const std::string& global_ptr, - const std::string& global_elem_offset, const std::string& bytes) { +std::string PrintCpAsyncAssembly(const std::string &shared_ptr, + const std::string &shared_elem_offset, + const std::string &global_ptr, + const std::string &global_elem_offset, + const std::string &bytes) { std::string asm_code = R"( { unsigned int addr = cast_smem_ptr_to_int({smem_addr}); @@ -678,22 +711,22 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, } )"; Replacer replacer; - replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); - replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); + replacer.register_rule("{smem_addr}", + shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", + global_ptr + " + " + global_elem_offset); replacer.register_rule("{bytes}", bytes); replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca"); asm_code = replacer.rewrite(asm_code); return asm_code; } -std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, - const std::string& shared_elem_offset, - const std::string& global_ptr, - const std::string& global_elem_offset, - const std::string& bytes, - const std::string& predicate_value) { - CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || bytes == "2" || - bytes == "1") +std::string PrintPredicatedCpAsyncAssembly( + const std::string &shared_ptr, const std::string &shared_elem_offset, + const std::string &global_ptr, const std::string &global_elem_offset, + const std::string &bytes, const std::string &predicate_value) { + CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || + bytes == "2" || bytes == "1") << "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async"; std::string predicated_asm_code = R"( { @@ -712,14 +745,16 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, ); } )"; - auto [store_shared, nopreg] = [](const std::string& bytes) { + auto [store_shared, nopreg] = [](const std::string &bytes) { if (bytes == "16") return std::make_tuple("st.shared.v4.u32 [%1], {%4, %5, %6, %7}", "\"r\"(0), \"r\"(0), \"r\"(0),\"r\"(0)"); else if (bytes == "12") - return std::make_tuple("st.shared.v3.u32 [%1], {%4, %5, %6}", "\"r\"(0), \"r\"(0), \"r\"(0)"); + return std::make_tuple("st.shared.v3.u32 [%1], {%4, %5, %6}", + "\"r\"(0), \"r\"(0), \"r\"(0)"); else if (bytes == "8") - return std::make_tuple("st.shared.v2.u32 [%1], {%4, %5}", "\"r\"(0), \"r\"(0)"); + return std::make_tuple("st.shared.v2.u32 [%1], {%4, %5}", + "\"r\"(0), \"r\"(0)"); else if (bytes == "4") return std::make_tuple("st.shared.u32 [%1], {%4}", "\"r\"(0)"); else if (bytes == "2") @@ -731,8 +766,10 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, }(bytes); Replacer replacer; - replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); - replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); + replacer.register_rule("{smem_addr}", + shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", + global_ptr + " + " + global_elem_offset); replacer.register_rule("{bytes}", bytes); replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca"); replacer.register_rule("{store_shared}", store_shared); @@ -742,11 +779,12 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, return predicated_asm_code; } -std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr, - const std::string& shared_elem_offset, - const std::string& global_ptr, - const std::string& global_elem_offset, const std::string& bytes, - const std::string& barrier) { +std::string PrintCpAsyncBulkAsm(const std::string &shared_ptr, + const std::string &shared_elem_offset, + const std::string &global_ptr, + const std::string &global_elem_offset, + const std::string &bytes, + const std::string &barrier) { std::string asm_code = R"( { unsigned int smem_addr_int = cast_smem_ptr_to_int({smem_addr}); @@ -760,15 +798,17 @@ std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr, )"; Replacer replacer; - replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); - replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); + replacer.register_rule("{smem_addr}", + shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", + global_ptr + " + " + global_elem_offset); replacer.register_rule("{bytes}", bytes); replacer.register_rule("{barrier}", "&" + barrier); asm_code = replacer.rewrite(asm_code); return asm_code; } -std::string PrintCpAsyncBarrierAsm(const std::string& barrier) { +std::string PrintCpAsyncBarrierAsm(const std::string &barrier) { std::string predicated_asm_code = R"( { unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); @@ -785,8 +825,8 @@ std::string PrintCpAsyncBarrierAsm(const std::string& barrier) { return predicated_asm_code; } -std::string PrintInitBarrierThreadCountAsm(const std::string& barrier, - const std::string& thread_count) { +std::string PrintInitBarrierThreadCountAsm(const std::string &barrier, + const std::string &thread_count) { std::string predicated_asm_code = R"( { unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); @@ -805,7 +845,7 @@ std::string PrintInitBarrierThreadCountAsm(const std::string& barrier, return predicated_asm_code; } -std::string PrintArriveBarrierAsm(const std::string& barrier) { +std::string PrintArriveBarrierAsm(const std::string &barrier) { std::string predicated_asm_code = R"( { unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); @@ -822,8 +862,8 @@ std::string PrintArriveBarrierAsm(const std::string& barrier) { return predicated_asm_code; } -std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier, - const std::string& byte_count) { +std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier, + const std::string &byte_count) { std::string predicated_asm_code = R"( { unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); @@ -842,7 +882,7 @@ std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier, return predicated_asm_code; } -std::string PrintWaitBarrierAsm(const std::string& barrier) { +std::string PrintWaitBarrierAsm(const std::string &barrier) { std::string predicated_asm_code = R"( { unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); @@ -860,5 +900,5 @@ std::string PrintWaitBarrierAsm(const std::string& barrier) { return predicated_asm_code; } -} // namespace codegen -} // namespace tvm::tl +} // namespace codegen +} // namespace tvm::tl diff --git a/src/target/ptx.h b/src/target/ptx.h index 72691fd44..15acb96b1 100644 --- a/src/target/ptx.h +++ b/src/target/ptx.h @@ -49,35 +49,38 @@ namespace codegen { * \param metadata Pointer to metadata buffer (only used for sparse mma). * \param metadata_offset The offset of element in metadata. * \param sparsity_selector The sparsity selector in sparse mma. - * \param bit_op The bit operator used in 1-bit mma, can be either "xor" or "and". - * \param sparse Whether it's sparse mma or not. - * \param saturate Whether saturate output or not. + * \param bit_op The bit operator used in 1-bit mma, can be either "xor" or + * "and". \param sparse Whether it's sparse mma or not. \param saturate Whether + * saturate output or not. */ -std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, - const std::string& B_layout, const std::string& A_dtype, - const std::string& B_dtype, const std::string& C_dtype, - const std::string& a_ptr, const std::string& a_offset, - const std::string& b_ptr, const std::string& b_offset, - const std::string& c_ptr, const std::string& c_offset, - const std::string& metadata, const std::string& metadata_offset, - const std::string& sparsity_selector, const std::string& bit_op, - bool sparse, bool saturate); +std::string +PrintMMAAssembly(const std::string &shape, const std::string &A_layout, + const std::string &B_layout, const std::string &A_dtype, + const std::string &B_dtype, const std::string &C_dtype, + const std::string &a_ptr, const std::string &a_offset, + const std::string &b_ptr, const std::string &b_offset, + const std::string &c_ptr, const std::string &c_offset, + const std::string &metadata, + const std::string &metadata_offset, + const std::string &sparsity_selector, + const std::string &bit_op, bool sparse, bool saturate); /*! * \brief Print ldmatrix assembly string given parameters. * \param trans: whether the matrix is loaded in column major format or not. * \param num: number of matrices to load. - * \param type: The data type in the matrix, .b16 is the only accepted data type. - * \param local_ptr: pointer to local buffer. - * \param local_elem_offset: The offset of the element to store in the local buffer. - * \param smem_ptr: pointer to the shared memory buffer to load. - * \param smem_elem_offset: The offset of the start element of the row to load in shared memory. + * \param type: The data type in the matrix, .b16 is the only accepted data + * type. \param local_ptr: pointer to local buffer. \param local_elem_offset: + * The offset of the element to store in the local buffer. \param smem_ptr: + * pointer to the shared memory buffer to load. \param smem_elem_offset: The + * offset of the start element of the row to load in shared memory. */ -std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type, - const std::string& local_ptr, - const std::string& local_elem_offset, - const std::string& smem_ptr, - const std::string& smem_elem_offset); +std::string PrintLoadMatrixAssembly(bool trans, int num, + const std::string &type, + const std::string &local_ptr, + const std::string &local_elem_offset, + const std::string &smem_ptr, + const std::string &smem_elem_offset); /*! * \brief Print ptx cp.async assembly string given parameters. @@ -87,10 +90,11 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type * \param global_elem_offset: The offset into the global memory. * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. */ -std::string PrintCpAsyncAssembly(const std::string& shared_ptr, - const std::string& shared_elem_offset, - const std::string& global_ptr, - const std::string& global_elem_offset, const std::string& bytes); +std::string PrintCpAsyncAssembly(const std::string &shared_ptr, + const std::string &shared_elem_offset, + const std::string &global_ptr, + const std::string &global_elem_offset, + const std::string &bytes); /*! * \brief Print predicated ptx cp.async assembly string given parameters. @@ -101,12 +105,10 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. * \param predicate_value: The value of predicate `@p`. */ -std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, - const std::string& shared_elem_offset, - const std::string& global_ptr, - const std::string& global_elem_offset, - const std::string& bytes, - const std::string& predicate_value); +std::string PrintPredicatedCpAsyncAssembly( + const std::string &shared_ptr, const std::string &shared_elem_offset, + const std::string &global_ptr, const std::string &global_elem_offset, + const std::string &bytes, const std::string &predicate_value); /*! * \brief Print ptx async copy from global to shared memory using cp.async.bulk @@ -117,48 +119,49 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, * \param bytes: The number of bytes to copy. * \param barrier: The name of the barrier in shared memory. */ -std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr, - const std::string& shared_elem_offset, - const std::string& global_ptr, - const std::string& global_elem_offset, const std::string& bytes, - const std::string& barrier); +std::string PrintCpAsyncBulkAsm(const std::string &shared_ptr, + const std::string &shared_elem_offset, + const std::string &global_ptr, + const std::string &global_elem_offset, + const std::string &bytes, + const std::string &barrier); /*! * \brief Print ptx async copy barrier using cp.async.mbarrier.arrive * \param barrier: The name of the barrier in shared memory. */ -std::string PrintCpAsyncBarrierAsm(const std::string& barrier); +std::string PrintCpAsyncBarrierAsm(const std::string &barrier); /*! * \brief Print ptx barrier initialization of thread count using mbarrier.init * \param barrier: The name of the barrier in shared memory. * \param thread_count: The number of threads expected to arrive at the barrier. */ -std::string PrintInitBarrierThreadCountAsm(const std::string& barrier, - const std::string& thread_count); +std::string PrintInitBarrierThreadCountAsm(const std::string &barrier, + const std::string &thread_count); /*! * \brief Print ptx barrier arrival using mbarrier.arrive * \param barrier: The name of the barrier in shared memory. */ -std::string PrintArriveBarrierAsm(const std::string& barrier); +std::string PrintArriveBarrierAsm(const std::string &barrier); /*! - * \brief Print ptx barrier arrival with expect tx operation using mbarrier.arrive.expect_tx - * \param barrier: The name of the barrier in shared memory. - * \param byte_count: Increases the tx count of the mbarrier object to track completion of - * addtional async transactions. + * \brief Print ptx barrier arrival with expect tx operation using + * mbarrier.arrive.expect_tx \param barrier: The name of the barrier in shared + * memory. \param byte_count: Increases the tx count of the mbarrier object to + * track completion of addtional async transactions. */ -std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier, - const std::string& byte_count); +std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier, + const std::string &byte_count); /*! * \brief Print ptx barrier wait using mbarrier.try_wait * \param barrier: The name of the barrier in shared memory. */ -std::string PrintWaitBarrierAsm(const std::string& barrier); +std::string PrintWaitBarrierAsm(const std::string &barrier); -} // namespace codegen -} // namespace tvm::tl +} // namespace codegen +} // namespace tvm::tl -#endif // TVM_TL_TARGET_SOURCE_PTX_H_ +#endif // TVM_TL_TARGET_SOURCE_PTX_H_ diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index 881c975e0..49ec95ce4 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -132,6 +132,7 @@ def test_gemm_ss(): run_gemm_ss(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) run_gemm_ss(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) + def matmul_rs( M, N, @@ -265,6 +266,7 @@ def test_gemm_rs(): run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) + def matmul_sr( M, N, @@ -537,22 +539,6 @@ def test_gemm_rr(): run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) + if __name__ == "__main__": - # tilelang.testing.main() - # run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float32", 128, 128, 32, 0) - # tilelang.disable_cache() - - run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 0) - run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 0) - run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 0) - run_gemm_sr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 0) - - run_gemm_rs(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 0) - run_gemm_rs(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 0) - run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 0) - run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 0) - - run_gemm_rr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 0) - run_gemm_rr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 0) - run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 0) - run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 0) + tilelang.testing.main() diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index f1b3d3d82..8ddd9f96d 100644 --- a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -8,6 +8,7 @@ def ldmatrix_32x4_to_shared_16x8_layout_a(thread_id, local_id): col = (thread_id // 16) * 4 + local_id % 4 return row, col + def ldmatrix_32x4_to_shared_16x8_layout_b(thread_id, local_id): row = (thread_id // 16) * 8 + (thread_id % 8) col = ((thread_id % 16) // 8) * 4 + local_id % 4 @@ -26,7 +27,6 @@ def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id): return row, col - def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id): row = thread_id % 16 col = local_id + (thread_id // 16) * 16 @@ -52,22 +52,27 @@ def shared_16x8_to_mma_a_32x4_layout(i, j): thread_id = 4 * (i % 8) + (j % 4) return thread_id, 2 * (j // 4) + (i // 8) + def shared_16x8_to_mma_a_32x4_layout_trans(i, j): return shared_16x8_to_mma_a_32x4_layout(j, i) + # mma.sync matrix B layout, if wanna trans, please apply map_indices def shared_16x8_to_mma_b_32x4_layout(i, j): thread_id = 4 * (i % 8) + (j % 4) return thread_id, 2 * (i // 8) + (j // 4) + def shared_16x8_to_mma_b_32x4_layout_trans(i, j): return shared_16x8_to_mma_b_32x4_layout(j, i) + shared_16x8_to_mma_32x4_layout_sr_a = shared_16x8_to_mma_a_32x4_layout shared_16x8_to_mma_32x4_layout_sr_b = shared_16x8_to_mma_b_32x4_layout shared_16x8_to_mma_32x4_layout_rs_a = shared_16x8_to_mma_a_32x4_layout_trans shared_16x8_to_mma_32x4_layout_rs_b = shared_16x8_to_mma_b_32x4_layout_trans + def shared_16x16_to_mma_a_32x8_layout(i, j): thread_id = 4 * (i % 8) + (j % 8) // 2 return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index c4cd0470d..cb999ac41 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -211,7 +211,10 @@ def ldmatrix_a(self, a_transposed = self.a_transposed # ldmatrix cannot be used for int8 + trans case. ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed) - mma_load_layout = lambda i, j: (i, j) + + def mma_load_layout(i, j): + return i, j + if not ldmatrix_available: if DataType(a_dtype).bits == 8: mma_load_layout = mma_load_a_32x16_to_shared_16x32_layout @@ -274,7 +277,10 @@ def ldmatrix_b(self, replicate_b = (self.n_dim == 16) # ldmatrix cannot be used for int8 + trans case. ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) - mma_load_layout = lambda i, j: (i, j) + + def mma_load_layout(i, j): + return i, j + if not ldmatrix_available: if DataType(b_dtype).bits == 8: mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout From aa62efbf09ba80d5306d3cc0504f63f7ca949446 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 11 Sep 2025 01:20:35 +0800 Subject: [PATCH 09/10] lint fix --- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index 49ec95ce4..1a8d7b93e 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -90,7 +90,6 @@ def run_gemm_ss( tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) - print(kernel.get_kernel_source()) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): From b5f327c4ce7f4287d44a2d8f7977299f22341a6e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 11 Sep 2025 02:21:50 +0800 Subject: [PATCH 10/10] Refactor shared memory allocation in GEMM tests - Removed unnecessary scope specification in shared memory allocation for matrices A and B in `test_tilelang_tilelibrary_gemm.py`. - This change simplifies the allocation process and aligns with the updated GEMM function signatures. --- .../test_tilelang_tilelibrary_gemm.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index 1a8d7b93e..984326434 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -31,8 +31,8 @@ def main( C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared") + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): @@ -162,8 +162,8 @@ def main( C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared") + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) A_frag = T.alloc_fragment(A_frag_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) @@ -296,8 +296,8 @@ def main( C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared") + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_frag = T.alloc_fragment(B_frag_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) @@ -431,8 +431,8 @@ def main( C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared") + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) A_frag = T.alloc_fragment(A_frag_shape, in_dtype) B_frag = T.alloc_fragment(B_frag_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)