From e82712fd53f4134f7817f80db44c8464ca5c6509 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 14 May 2025 19:43:44 +0800 Subject: [PATCH 1/6] Remove debug print statement from block_sparse_attn_triton.py and implement a timeout handler in autotuner for function execution. This enhances the robustness of the autotuner by allowing it to handle timeouts gracefully. --- .../block_sparse_attn_triton.py | 2 -- tilelang/autotuner/__init__.py | 24 ++++++++++++++----- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py index 0c41c99e4..d81903b06 100644 --- a/examples/blocksparse_attention/block_sparse_attn_triton.py +++ b/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -62,8 +62,6 @@ def _fwd_kernel_inner( mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) # print - if k_block_col_idx == 3: - print("mask_val", mask_val) if mask_val == True: start_n = k_block_col_idx * BLOCK_N # -- compute qk ---- diff --git a/tilelang/autotuner/__init__.py b/tilelang/autotuner/__init__.py index 2e71b8b73..18d6f5323 100644 --- a/tilelang/autotuner/__init__.py +++ b/tilelang/autotuner/__init__.py @@ -19,6 +19,22 @@ import torch import os import sys +import signal + +class TimeoutException(Exception): + pass + +def timeout_handler(signum, frame): + raise TimeoutException() + +def run_with_timeout(func, timeout, *args, **kwargs): + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(timeout) + try: + result = func(*args, **kwargs) + finally: + signal.alarm(0) + return result # Configure logging for the autotuner module # TODO: Consider creating a common logger in utils @@ -376,12 +392,8 @@ def device_wrapper(func, device, *config_arg): # Cannot ThreadPoolExecutor to enforce timeout on target_fn execution # Because tma init may behave strangely with one thread # latency, ref_latency = target_fn(jit_context) - benchmark_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - future = benchmark_executor.submit( - functools.partial(device_wrapper, target_fn, torch.cuda.current_device()), - jit_context) - latency, ref_latency = future.result(timeout=timeout) - except concurrent.futures.TimeoutError: + latency, ref_latency = run_with_timeout(target_fn, timeout, jit_context) + except TimeoutException: logger.info( f"A timeout occurred while testing config {config}, checkout autotuner.log for more details" ) From 241709410cf4b9e33fb3dff71a9001922c5c1ce8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 14 May 2025 19:43:53 +0800 Subject: [PATCH 2/6] Enhance the autotuner module by adding a timeout handler for function execution, improving robustness in handling long-running tasks. This change includes the introduction of a custom TimeoutException and updates to the run_with_timeout function for better signal management. 
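For reference, a minimal usage sketch of the signal-based helper introduced by these two patches. The slow_benchmark function and the 5-second budget below are illustrative only and not part of the patch; note that signal.alarm is Unix-only, the SIGALRM handler must be installed from the main thread, and the alarm is always cleared in the finally block so a fast run does not leave a stale timer behind.

import signal
import time


class TimeoutException(Exception):
    pass


def timeout_handler(signum, frame):
    raise TimeoutException()


def run_with_timeout(func, timeout, *args, **kwargs):
    # Mirrors the helper added in this patch: arm SIGALRM, run func,
    # and always disarm the alarm, even if func raises.
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout)
    try:
        result = func(*args, **kwargs)
    finally:
        signal.alarm(0)
    return result


def slow_benchmark():
    # Hypothetical stand-in for target_fn(jit_context) in the autotuner.
    time.sleep(10)
    return 1.0, 1.0


try:
    latency, ref_latency = run_with_timeout(slow_benchmark, 5)
except TimeoutException:
    print("config timed out, skipping")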
--- tilelang/autotuner/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tilelang/autotuner/__init__.py b/tilelang/autotuner/__init__.py index 18d6f5323..716d0208f 100644 --- a/tilelang/autotuner/__init__.py +++ b/tilelang/autotuner/__init__.py @@ -21,12 +21,15 @@ import sys import signal + class TimeoutException(Exception): pass + def timeout_handler(signum, frame): raise TimeoutException() + def run_with_timeout(func, timeout, *args, **kwargs): signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(timeout) @@ -36,6 +39,7 @@ def run_with_timeout(func, timeout, *args, **kwargs): signal.alarm(0) return result + # Configure logging for the autotuner module # TODO: Consider creating a common logger in utils logger = logging.getLogger(__name__) From 6eeeeaa2741cd7b02133993d6b857fa203cef002 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 16 May 2025 15:29:45 +0800 Subject: [PATCH 3/6] Add merge shared memory allocations pass and related configurations - Introduced a new pass for merging shared memory allocations in GPU kernels, allowing for more efficient memory usage. - Registered configuration options for debugging and controlling the merging behavior. - Updated relevant files to integrate the new pass into the TileLang engine and transform modules. - Adjusted import paths and added documentation for the new functionality. --- src/op/builtin.cc | 1 + src/op/builtin.h | 3 +- .../merge_shared_memory_allocations.cc | 828 ++++++++++++++++++ tilelang/engine/phase.py | 2 +- tilelang/transform/__init__.py | 11 + tilelang/transform/pass_config.py | 3 + 6 files changed, 846 insertions(+), 2 deletions(-) create mode 100644 src/transform/merge_shared_memory_allocations.cc diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 802877ca6..157738a4d 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -19,6 +19,7 @@ namespace tvm { namespace tl { +TVM_REGISTER_PASS_CONFIG_OPTION(kDebugMergeSharedMemoryAllocations, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool); diff --git a/src/op/builtin.h b/src/op/builtin.h index 885175610..e059ab8fa 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -15,7 +15,8 @@ namespace tvm { namespace tl { - +static constexpr const char *kDebugMergeSharedMemoryAllocations = + "tl.debug_merge_shared_memory_allocations"; static constexpr const char *kDisableTMALower = "tl.disable_tma_lower"; static constexpr const char *kDisableSafeMemoryLegalize = "tl.disable_safe_memory_legalize"; diff --git a/src/transform/merge_shared_memory_allocations.cc b/src/transform/merge_shared_memory_allocations.cc new file mode 100644 index 000000000..a3670d62e --- /dev/null +++ b/src/transform/merge_shared_memory_allocations.cc @@ -0,0 +1,828 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file merge_shared_memory_allocations.cc + * \brief Each GPU kernel is allowed to have only one dynamic or static shared + * memory allocation. This pass merges multiple TIR-level dynamic or static + * shared memory allocations into one allocation. + */ +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../op/builtin.h" +#include "runtime/thread_storage_scope.h" +#include "support/arena.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +using runtime::StorageRank; +using runtime::StorageScope; + +static bool IsDynamicSharedMemory(Var buffer_var) { + StorageScope storage_scope = + runtime::StorageScope::Create(GetPtrStorageScope(buffer_var)); + return storage_scope.rank == runtime::StorageRank::kShared && + storage_scope.tag == ".dyn"; +} + +static bool IsStaticSharedMemory(Var buffer_var) { + StorageScope storage_scope = + runtime::StorageScope::Create(GetPtrStorageScope(buffer_var)); + return storage_scope.rank == runtime::StorageRank::kShared && + storage_scope.tag == ""; +} + +/*! + * \brief collect the mapping from the buffer var to its allocate + */ +class AllocateCollector : public StmtExprVisitor { +public: + void VisitStmt_(const AllocateNode *op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + dyn_shmem_allocs_[op->buffer_var.get()] = op; + } else if (IsStaticSharedMemory(op->buffer_var)) { + static_shmem_allocs_[op->buffer_var.get()] = op; + } + StmtExprVisitor::VisitStmt_(op); + } + // The dynamic mapping from the original buffer var to its allocate + std::unordered_map dyn_shmem_allocs_; + // The static mapping from the original buffer var to its allocate + std::unordered_map + static_shmem_allocs_; +}; + +// Find a linear pattern of storage access +// Used for liveness analysis. +// "linear" means fitting a complex access pattern into an array of StmtEntry +// +// Define "scope" as the body of For/thread_launch/IfThenElse +// Composite scopes(loop/thread_launch/IfThen) is represented by three +// StmtEntry: before_scope -> scope_body -> after_scope +// +// This pass tries to detect last point that we need to keep memory +// alive under the same scope as Allocate. +// The storage need to be kept alive between Allocate and last access. +// The free point is only inserted at the same scope of Allocate. +// +class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { +public: + explicit SharedMemLinearAccessPatternFinder(bool is_dynamic = true, + bool verbose = false) + : is_dynamic_(is_dynamic), verbose_(verbose) {} + /*! \brief record the touch list of statement. */ + struct StmtEntry { + // The statement + const Object *stmt; + // The index in the linear_seq_ to point to end of the nested scope. + // This is only set to non-zero if stmt is a nested scope. + // if offset > 0, means this is the begin, the end entry is current_index + + // offset if offset < 0, means this is the end, the begin entry is + // current_index + offset + int64_t scope_pair_offset{0}; + // The buffer variables this statement touched. 
+ std::vector touched; + }; + // The scope of each allocation + struct AllocEntry { + // the level in the scope stack + size_t level{0}; + // allocation stmt + const AllocateNode *alloc{nullptr}; + }; + + void VisitStmt_(const AllocateNode *op) final { + size_t level = scope_.size(); + const VarNode *buf = op->buffer_var.get(); + alloc_info_[buf].alloc = op; + alloc_info_[buf].level = level; + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + scope_.push_back(StmtEntry()); + // visit subexpr + StmtExprVisitor::VisitStmt_(op); + // Add write access. + const VarNode *buf = op->buffer->data.get(); + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()); + if (IsAppropriateSharedMemory(GetRef(buf))) { + scope_[it->second.level].touched.push_back(buf); + } + } + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (e.touched.size() != 0) { + e.stmt = op; + linear_seq_.push_back(e); + } + } + + void VisitStmt_(const EvaluateNode *op) final { + scope_.push_back(StmtEntry()); + // visit subexpr + StmtExprVisitor::VisitStmt_(op); + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (e.touched.size() != 0) { + e.stmt = op; + linear_seq_.push_back(e); + } + } + + void VisitExpr_(const BufferLoadNode *op) final { + // Add write access. + StmtExprVisitor::VisitExpr_(op); + const VarNode *buf = op->buffer->data.get(); + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()) + << "Load memory in places other than store."; + if (IsAppropriateSharedMemory(GetRef(buf))) { + scope_[it->second.level].touched.push_back(buf); + } + } + } + + void VisitExpr_(const VarNode *buf) final { + // Directly reference to the variable count as a read. + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()); + if (IsAppropriateSharedMemory(GetRef(buf))) { + scope_[it->second.level].touched.push_back(buf); + } + } + } + + template void VisitNewScope(const T *op) { + scope_.push_back(StmtEntry()); + StmtEntry e; + e.stmt = op; + int64_t begin_index = static_cast(linear_seq_.size()); + // before scope. + linear_seq_.push_back(e); + StmtExprVisitor::VisitStmt_(op); + // after scope. + e.touched = std::move(scope_.back().touched); + scope_.pop_back(); + int64_t end_index = static_cast(linear_seq_.size()); + ICHECK_GT(end_index, begin_index); + e.scope_pair_offset = begin_index - end_index; + linear_seq_.push_back(e); + // record the pointer to end index. + ICHECK_NE(end_index, 0U); + linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; + } + + void VisitStmt_(const AttrStmtNode *op) final { + // Only record the outer most thread extent. 
+ if (op->attr_key == tir::attr::thread_extent && !in_thread_env_) { + in_thread_env_ = true; + VisitNewScope(op); + in_thread_env_ = false; + } else if (op->attr_key == tir::attr::extern_scope) { + VisitNewScope(op); + } else if (op->attr_key == tir::attr::virtual_thread) { + VisitNewScope(op); + } else if (op->attr_key == "kWarpSpecializationScope") { + IfThenElse body = Downcast(op->body); + this->VisitStmt(body->then_case); + this->VisitStmt(body->else_case.value()); + } else { + StmtExprVisitor::VisitStmt_(op); + } + } + + void VisitStmt_(const IfThenElseNode *op) final { VisitNewScope(op); } + + void VisitStmt_(const ForNode *op) final { VisitNewScope(op); } + + void VisitStmt_(const WhileNode *op) final { VisitNewScope(op); } + + void VisitStmt_(const AssertStmtNode *op) final { VisitNewScope(op); } + + // linearized access sequence. + std::vector linear_seq_; + // The storage scope of each buffer + std::unordered_map alloc_info_; + +private: + // Wrapper function to determine if the shared memory allocation for a + // variable is appropriate. + bool IsAppropriateSharedMemory(const Var &var) { + return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var); + } + // Whether do dyanmic analysis. + bool is_dynamic_{true}; + // Whether do verbose logging. + bool verbose_{false}; + // Whether already in thread env. + bool in_thread_env_{false}; + // The scope stack. + std::vector scope_; +}; + +/*! + * \brief merge the buffers whose live range has no intersection and rewrite the + * body + */ +class SharedMemoryRewriter : public StmtExprMutator { +public: + explicit SharedMemoryRewriter( + const std::unordered_map + &shmem_allocs, + bool is_dynamic = true, bool verbose = false) + : is_dynamic_{is_dynamic}, shmem_allocs_{shmem_allocs}, verbose_{ + verbose} { + if (!is_dynamic) { + merged_buf_var_ = + Var("buf_shmem", PointerType(PrimType(DataType::UInt(8)), "shared")); + } + } + + /*! + * \brief plan the memory reuse for all the buffer allocated in the statement + * \param stmt the statement + */ + void PlanReuse(const Stmt &stmt, bool is_dynamic = true, + bool verbose = false) { + SharedMemLinearAccessPatternFinder finder(is_dynamic, verbose); + finder(stmt); + this->LivenessAnalysis(finder.linear_seq_); + this->PlanMemory(finder.linear_seq_); + } + +private: + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent && !allocated_) { + // Allocate one dynamic shared memory allocation at the beginning of + // thread scope + int max_layer_num = 0; + std::vector all_entry; + for (const auto &e : const_free_map_) { + all_entry.push_back(e.second); + } + for (const StorageEntry *e : sym_free_list_) { + all_entry.push_back(e); + } + for (const StorageEntry *e : all_entry) { + max_layer_num = + std::max(max_layer_num, static_cast(e->allocs.size())); + } + // calculate align for each layer of each storage entry. 
+ std::vector align(max_layer_num, 0); + for (const StorageEntry *e : all_entry) { + for (int i = 0; i < static_cast(e->allocs.size()); i++) { + for (const VarNode *buffer : e->allocs[i]) { + const AllocateNode *alloc = shmem_allocs_[buffer]; + align[i] = + std::max(align[i], alloc->dtype.bytes() * alloc->dtype.lanes()); + } + } + } + // calculate offset for each buffer based on the align of each layer + for (const StorageEntry *e : all_entry) { + PrimExpr max_inner_offset = 0; + for (int i = 0; i < static_cast(e->allocs.size()); i++) { + PrimExpr inner_offset = 0; + for (const VarNode *buffer : e->allocs[i]) { + const AllocateNode *alloc = shmem_allocs_[buffer]; + buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset; + inner_offset += + alloc->extents[0] * alloc->dtype.bytes() * alloc->dtype.lanes(); + inner_offset += + indexmod(align[i] - indexmod(inner_offset, align[i]), align[i]); + } + max_inner_offset = max(max_inner_offset, inner_offset); + } + merged_alloc_size_ += max_inner_offset; + } + + if (verbose_) { + + LOG(DEBUG) << "Memory Allocation Plan for " + << (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:"; + LOG(DEBUG) << " Merged Buffer Name: " << merged_buf_var_->name_hint; + LOG(DEBUG) << " Total Merged Size: " << merged_alloc_size_ << " bytes"; + LOG(DEBUG) << " Individual Buffer Allocations:"; + for (const auto &pair : buffer_byte_offsets_) { + const VarNode *buffer_var_node = pair.first; + PrimExpr byte_offset = pair.second; + auto alloc_it = shmem_allocs_.find(buffer_var_node); + if (alloc_it != shmem_allocs_.end()) { + const AllocateNode *alloc = alloc_it->second; + PrimExpr buffer_size_bytes = + alloc->extents[0] * alloc->dtype.bytes() * alloc->dtype.lanes(); + LOG(DEBUG) << " Buffer: " << buffer_var_node->name_hint + << " (Type: " << alloc->dtype << ")" + << ", Start Offset: " << byte_offset + << ", Size: " << buffer_size_bytes << " bytes" + << ", End Offset: " + << (byte_offset + buffer_size_bytes - 1); + } else { + LOG(DEBUG) << " Buffer: " << buffer_var_node->name_hint + << ", Start Offset: " << byte_offset + << " (Original allocation info not found)"; + } + } + LOG(DEBUG) << "End of Memory Allocation Plan."; + } + + allocated_ = true; + Allocate new_body(merged_buf_var_, DataType::UInt(8), + {merged_alloc_size_}, const_true(), + StmtExprMutator::VisitStmt(op->body)); + return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span); + } + return StmtMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const AllocateNode *op) final { + if (IsAppropriateSharedMemory(op->buffer_var)) { + return StmtExprMutator::VisitStmt(op->body); + } + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const DeclBufferNode *op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + auto new_buf = GetUpdatedBuffer(node->buffer); + if (!new_buf.same_as(node->buffer)) { + node.CopyOnWrite()->buffer = new_buf; + } + return std::move(node); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + template Node VisitBufferAccess(Node node) { + if (IsAppropriateSharedMemory(node->buffer->data)) { + ICHECK_EQ(node->indices.size(), 1) + << "MergeSharedMemoryAllocations expects flat memory buffers, " + << "and is to be run after " + << "StorageFlatten (TE schedules) or 
FlattenBuffer (TIR schedules)"; + Array indices = { + node->indices[0] + + this->GetBufferOffset(node->buffer->data, node->buffer->dtype)}; + + auto writer = node.CopyOnWrite(); + writer->buffer = GetUpdatedBuffer(node->buffer); + writer->indices = indices; + } + + return node; + } + + Buffer GetUpdatedBuffer(Buffer buffer) { + auto key = buffer.get(); + auto it = buffer_remap_.find(key); + if (it != buffer_remap_.end()) { + return it->second; + } + + if (IsAppropriateSharedMemory(buffer->data)) { + ICHECK_EQ(buffer->shape.size(), 1) + << "Buffer " << buffer << " has shape " << buffer->shape << ". " + << "MergeSharedMemoryAllocations expects flat memory buffers, " + << "and is to be run after " + << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)"; + auto writer = buffer.CopyOnWrite(); + writer->data = merged_buf_var_; + } + + buffer_remap_[key] = buffer; + return buffer; + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(op->args.size(), 5U); + DataType dtype = op->args[0].dtype(); + Var buffer = Downcast(op->args[1]); + if (!IsAppropriateSharedMemory(buffer)) { + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr extra_offset = GetBufferOffset(buffer, dtype); + + PrimExpr offset = this->VisitExpr(op->args[2]); + PrimExpr extent = this->VisitExpr(op->args[3]); + return Call(op->dtype, op->op, + {op->args[0], merged_buf_var_, extra_offset + offset, extent, + op->args[4]}); + } else if (op->op.same_as(builtin::ptx_cp_async())) { + ICHECK((op->args.size() == 5U) || (op->args.size() == 6U)); + DataType dtype = op->dtype; + Var buffer = Downcast(op->args[0]); + if (!IsAppropriateSharedMemory(buffer)) { + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr extra_offset = GetBufferOffset(buffer, dtype); + PrimExpr offset = this->VisitExpr(op->args[1]); + // the dst shared memory is a byte buffer generated by merging shared + // memory. we need to multiply the offset index by the byte size of the + // original value dtype, to get the correct offset of merged shared + // buffer. + int index_factor = dtype.bytes(); + if (op->args.size() == 5) + return Call(dtype, op->op, + {merged_buf_var_, + mul(extra_offset + offset, PrimExpr(index_factor)), + op->args[2], op->args[3], op->args[4]}); + else + return Call(dtype, op->op, + {merged_buf_var_, + mul(extra_offset + offset, PrimExpr(index_factor)), + op->args[2], op->args[3], op->args[4], op->args[5]}); + } else { + return StmtExprMutator::VisitExpr_(op); + } + } + + PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { + auto it = buffer_byte_offsets_.find(buffer_var.get()); + ICHECK(it != buffer_byte_offsets_.end()) + << "buffer_var = " << buffer_var->name_hint << ", dtype = " << dtype; + return indexdiv(it->second, dtype.bytes() * dtype.lanes()); + } + + // Wrapper function to determine if the shared memory allocation for a + // variable is appropriate. + bool IsAppropriateSharedMemory(const Var &var) { + return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var); + } + + using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry; + struct StorageEntry { + // The constant size of the buffer in bits, only used if it is constant + uint64_t const_nbits{0}; + // Allocs that shares this entry. 
+ // The inner vector means a "layer" + // For example, it we need to allocate C in the memory of A and B: + // | A: 4096 bytes | B: 4096 bytes | + // | C: 8192 bytes | + // Then the allocs = {{A, B}, {C}} + std::vector> allocs; + }; + + // Event entry in liveness analysis + struct EventEntry { + // variables we generate + std::vector gen; + // variables we kill + std::vector kill; + }; + + /*! + * \brief Liveness analysis to find gen and kill point of each variable. + * \param seq the linear pattern of storage access + */ + void LivenessAnalysis(const std::vector &seq) { + // find kill point, do a reverse linear scan. + std::unordered_set touched; + for (size_t i = seq.size(); i != 0; --i) { + const StmtEntry &s = seq[i - 1]; + for (const VarNode *buffer : s.touched) { + if (!touched.count(buffer)) { + touched.insert(buffer); + event_map_[s.stmt].kill.push_back(buffer); + } + } + } + // find gen point, do forward scan + touched.clear(); + for (size_t i = 0; i < seq.size(); ++i) { + int64_t offset = seq[i].scope_pair_offset; + if (offset < 0) + continue; + const StmtEntry &s = seq[i + offset]; + for (const VarNode *buffer : s.touched) { + if (!touched.count(buffer)) { + touched.insert(buffer); + event_map_[s.stmt].gen.push_back(buffer); + } + } + } + + if (verbose_) { + LOG(DEBUG) << "Liveness Analysis Results for " + << (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:"; + for (const auto &pair : event_map_) { + const Object *stmt_obj = pair.first; + const EventEntry &entry = pair.second; + + if (entry.gen.empty() && entry.kill.empty()) { + continue; // Skip statements with no gen/kill events for brevity + } + + LOG(DEBUG) << " Statement: " << stmt_obj->GetTypeKey(); + + std::stringstream gen_vars_ss; + bool x_generated = false; + for (const VarNode *var : entry.gen) { + gen_vars_ss << var->name_hint << " "; + if (var->name_hint == "x") { + x_generated = true; + } + } + if (!entry.gen.empty()) { + std::string gen_log_msg = " GEN: " + gen_vars_ss.str(); + if (x_generated) { + gen_log_msg += " <-- Buffer 'x' generated"; + } + LOG(DEBUG) << gen_log_msg; + } + + std::stringstream kill_vars_ss; + bool x_killed = false; + for (const VarNode *var : entry.kill) { + kill_vars_ss << var->name_hint << " "; + if (var->name_hint == "x") { + x_killed = true; + } + } + if (!entry.kill.empty()) { + std::string kill_log_msg = " KILL: " + kill_vars_ss.str(); + if (x_killed) { + kill_log_msg += " <-- Buffer 'x' killed"; + } + LOG(DEBUG) << kill_log_msg; + } + } + LOG(DEBUG) << "End of Liveness Analysis Results."; + } + } + + /*! 
+ * \brief Memory plan algorithm + * \param seq the linear pattern of storage access + * \param alloc_info + */ + void PlanMemory(const std::vector &seq) { + std::unordered_set inplace_flag; + + for (size_t i = 0; i < seq.size(); ++i) { + auto it = event_map_.find(seq[i].stmt); + // scope_pair_offset <= 0 means it is either + // - leaf stmt(offset = 0) + // - end of scope(offset < 0) + // In both cases, we need to handle the kill event correctly + auto is_leaf_alloc = [&](const VarNode *var) { + return seq[i].scope_pair_offset == 0 && + std::find(it->second.gen.begin(), it->second.gen.end(), var) != + it->second.gen.end(); + }; + if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { + for (const VarNode *var : it->second.kill) { + if (!is_leaf_alloc(var)) + this->Free(var); + } + } + // scope_pair_offset >= 0 means it is either + // - leaf stmt(offset = 0) + // - beginning of scope(offset < 0) + // In both cases, we need to handle the gen event correctly + if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) { + for (const VarNode *var : it->second.gen) { + ICHECK(shmem_allocs_.count(var)); + const AllocateNode *alloc = shmem_allocs_[var]; + StorageEntry *dst_entry = FindAlloc(alloc); + alloc_map_[var] = dst_entry; + } + } + if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { + for (const VarNode *var : it->second.kill) { + if (is_leaf_alloc(var)) + this->Free(var); + } + } + } + } + /*! + * \brief Allocate new storage entry. + * \param op the allocate node + * \param the size of the allocation in bits + * \return the new storage entry + */ + StorageEntry *NewAlloc(const AllocateNode *op, size_t const_nbits) { + ICHECK(op != nullptr); + // Re-use not successful, allocate a new buffer. + StorageEntry *entry = arena_.make(); + entry->allocs.push_back({op->buffer_var.get()}); + entry->const_nbits = const_nbits; + return entry; + } + /*! + * \brief find the storage entry in the free list for the allocate + * \param op the allocate node + * \return the storage entry + */ + StorageEntry *FindAlloc(const AllocateNode *op) { + ICHECK(op != nullptr); + // skip plan for local variable, + // compiler can do a better job with register allocation. + const uint64_t match_range = 16; + uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes(); + uint64_t const_nbits = + static_cast(op->ConstantAllocationSize() * op_elem_bits); + // disable reuse of small arrays, they will be lowered to registers in LLVM + // This rules only apply if we are using non special memory + if (const_nbits > 0 && const_nbits <= 32) { + return NewAlloc(op, const_nbits); + } + + if (const_nbits != 0) { + // constant allocation. + auto begin = const_free_map_.lower_bound(0); + auto mid = const_free_map_.lower_bound(const_nbits); + auto end = const_free_map_.upper_bound(const_nbits * match_range); + // Start looking at the buffer that is bigger than the required size + // first. If we find one, directly allocate the buffer in its location and + // remove its entry in the free list + for (auto it = mid; it != end; ++it) { + StorageEntry *e = it->second; + e->const_nbits = std::max(const_nbits, e->const_nbits); + const_free_map_.erase(it); + it->second->allocs.push_back({op->buffer_var.get()}); + return e; + } + // Then start looking at smaller buffers. 
+ // Keep collecting the buffer until the sum of their size exceeds the + // buffer to allocate and finally free all these entry in the free list + std::vector::iterator> delete_it; + // the alloc list for the new entry + std::vector> reuse_allocs; + uint64_t mem_ct = 0; + for (auto it = mid; it != begin;) { + --it; + delete_it.push_back(it); + mem_ct += it->second->const_nbits; + int n = it->second->allocs.size(); + if (n > static_cast(reuse_allocs.size())) { + reuse_allocs.resize(n, {}); + } + for (int i = 0; i < n; i++) { + for (const VarNode *alloc : it->second->allocs[i]) { + reuse_allocs[i].push_back(alloc); + } + } + if (mem_ct >= const_nbits) { + break; + } + } + reuse_allocs.push_back({op->buffer_var.get()}); + if (mem_ct != 0) { + StorageEntry *e = arena_.make(); + e->const_nbits = std::max(const_nbits, mem_ct); + e->allocs = reuse_allocs; + for (auto it : delete_it) { + const_free_map_.erase(it); + } + return e; + } + } else { + // if its symbolic allocation, just arbitrarily choose one entry to fit in + // because we don't know its actual size + for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) { + StorageEntry *e = *it; + sym_free_list_.erase(it); + return e; + } + } + return NewAlloc(op, const_nbits); + } + + /*! + * \brief add the storage entry to the buffer var into the free list. + * \param var the buffer var + */ + void Free(const VarNode *var) { + auto it = alloc_map_.find(var); + ICHECK(it != alloc_map_.end()); + StorageEntry *e = it->second; + ICHECK_NE(e->allocs.size(), 0U); + + // disable reuse of small arrays + if (e->const_nbits > 0 && e->const_nbits <= 32) + return; + + // normal free. + if (e->const_nbits != 0) { + const_free_map_.insert({e->const_nbits, e}); + } else { + sym_free_list_.push_back(e); + } + } + // Wheather enable dyanmic analysis. + bool is_dynamic_{true}; + // Whether enable verbose logging. + bool verbose_{false}; + // The var for the merged buffer + Var merged_buf_var_{"buf_dyn_shmem", + PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; + // The mapping from the original buffer var to its allocate + std::unordered_map shmem_allocs_; + // The size of the merged buffer + PrimExpr merged_alloc_size_{0}; + // The mapping from the original buffer var to its offset in the merged buffer + std::unordered_map buffer_byte_offsets_; + // The mapping from the original buffer objects to their location in the + // merged buffer. + std::unordered_map buffer_remap_; + // The flag indicating whether the merged buffer has been allocated + bool allocated_{false}; + // Locations of free ops. + std::unordered_map event_map_; + // constant size free map. + std::multimap const_free_map_; + // symbolic free list, for non constant items. + std::list sym_free_list_; + // The allocation assign map + std::unordered_map alloc_map_; + /*! 
\brief allocator of all the StorageEntry*/ + support::Arena arena_; +}; + +Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem, + bool verbose = false) { + AllocateCollector collector; + collector(stmt); + if (collector.dyn_shmem_allocs_.size() > 1) { + SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_, true, verbose); + rewriter.PlanReuse(stmt); + stmt = rewriter(std::move(stmt)); + } + if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) { + SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false, + verbose); + rewriter.PlanReuse(stmt, false); + stmt = rewriter(std::move(stmt)); + } + return stmt; +} + +using namespace tir::transform; + +namespace transform { + +Pass MergeSharedMemoryAllocations() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + bool merge_static_smem = + ctx->GetConfig("tir.merge_static_smem", Bool(false)).value(); + bool debug_merge_shared_memory_allocations = + ctx->GetConfig(kDebugMergeSharedMemoryAllocations, Bool(false)) + .value(); + auto *n = f.CopyOnWrite(); + n->body = + tl::MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem, + debug_merge_shared_memory_allocations); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.MergeSharedMemoryAllocations", + {}); +} + +TVM_REGISTER_GLOBAL("tl.transform.MergeSharedMemoryAllocations") + .set_body_typed(MergeSharedMemoryAllocations); + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 21c9ee0c6..96d5fa875 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -130,7 +130,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.AnnotateDeviceRegions()(mod) mod = tir.transform.SplitHostDevice()(mod) - mod = tir.transform.MergeSharedMemoryAllocations()(mod) + mod = tilelang.transform.MergeSharedMemoryAllocations()(mod) mod = tilelang.transform.MakePackedAPI()(mod) mod = tir.transform.LowerDeviceKernelLaunch()(mod) diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 2467ac581..9efe5836e 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -333,3 +333,14 @@ def EliminateStorageSyncForMBarrier(): """EliminateStorageSyncForMBarrier """ return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore + + +def MergeSharedMemoryAllocations(): + """MergeSharedMemoryAllocations + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.MergeSharedMemoryAllocations() # type: ignore diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 7ae6ca4ee..1ad7dd706 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -29,6 +29,9 @@ class PassConfigKey(str, Enum): TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize" """Disable safe memory access optimization. Default: False""" + TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations" + """Enable debug information for merge shared memory allocations. Default: False""" + # TIR related configs TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir" """Enable equivalent terms in TIR Common Subexpression Elimination. 
Default: True""" From aa8ab88307f9190f63fd1c40205d796e870a0437 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 16 May 2025 12:59:27 +0000 Subject: [PATCH 4/6] Reduce num_stages parameter in GEMM functions from 3 to 1 for improved performance in test_tilelang_kernel_gemm.py --- testing/python/kernel/test_tilelang_kernel_gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testing/python/kernel/test_tilelang_kernel_gemm.py b/testing/python/kernel/test_tilelang_kernel_gemm.py index 8c8523719..4e5259fe4 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm.py @@ -357,7 +357,7 @@ def run_gemm_sr( block_M, block_N, block_K, - num_stages=3, + num_stages=1, num_threads=128, ): program = matmul_sr( @@ -473,7 +473,7 @@ def run_gemm_rs( block_M, block_N, block_K, - num_stages=3, + num_stages=1, num_threads=128, ): program = matmul_rs( From 73fb9a0aff0f3b046249b999d0c29f958d6e48bb Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 17 May 2025 11:07:31 +0000 Subject: [PATCH 5/6] Update Copy type in OperandTraits for GEMM templates to use conditional selection based on num_warp_n. This change enhances memory access patterns for different configurations in CUDA kernels. --- src/tl_templates/cuda/gemm_sm80.h | 6 ++++-- src/tl_templates/cuda/gemm_sm89.h | 6 ++++-- src/tl_templates/cuda/gemm_sm90.h | 6 ++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/tl_templates/cuda/gemm_sm80.h b/src/tl_templates/cuda/gemm_sm80.h index a79a5ccf1..55d18c1b1 100644 --- a/src/tl_templates/cuda/gemm_sm80.h +++ b/src/tl_templates/cuda/gemm_sm80.h @@ -98,7 +98,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<2, 3, 3>{}, Layout, Stride<_1, _32>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template @@ -108,7 +109,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<3, 3, 3>{}, Layout, Stride<_1, _64>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template diff --git a/src/tl_templates/cuda/gemm_sm89.h b/src/tl_templates/cuda/gemm_sm89.h index 4f7058896..8e326f86d 100644 --- a/src/tl_templates/cuda/gemm_sm89.h +++ b/src/tl_templates/cuda/gemm_sm89.h @@ -201,7 +201,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<2, 3, 3>{}, Layout, Stride<_1, _32>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template @@ -211,7 +212,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<3, 3, 3>{}, Layout, Stride<_1, _64>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template diff --git a/src/tl_templates/cuda/gemm_sm90.h b/src/tl_templates/cuda/gemm_sm90.h index 313793cd2..bf55499c8 100644 --- a/src/tl_templates/cuda/gemm_sm90.h +++ b/src/tl_templates/cuda/gemm_sm90.h @@ -255,7 +255,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<2, 3, 3>{}, Layout, Stride<_1, _32>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template @@ -265,7 
+266,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<3, 3, 3>{}, Layout, Stride<_1, _64>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template From 08f2418ed5ee730ddb01533c6d7909bd93743176 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 17 May 2025 11:08:24 +0000 Subject: [PATCH 6/6] lint fix --- examples/deepseek_mla/test_example_mla_decode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/deepseek_mla/test_example_mla_decode.py b/examples/deepseek_mla/test_example_mla_decode.py index d011526b3..9cde90b83 100644 --- a/examples/deepseek_mla/test_example_mla_decode.py +++ b/examples/deepseek_mla/test_example_mla_decode.py @@ -8,6 +8,7 @@ @tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_mla_decode(): with mock.patch.object(sys, 'argv', ["example_mla_decode.py"]): example_mla_decode.main()
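As a usage note for patch 3/6, the new debug switch can be flipped from Python roughly as sketched below. This is a minimal sketch, assuming a tilelang build that contains this series; mod stands in for any IRModule that has already been lowered and split into host/device functions (as in OptimizeForTarget), and the config key string is the one registered in src/op/builtin.cc.

import tvm
import tilelang.transform


def merge_shared_memory_with_debug(mod: tvm.IRModule) -> tvm.IRModule:
    # Enabling the new key makes the pass log its liveness analysis and the
    # per-buffer offsets chosen for the merged shared-memory allocation.
    # Adding "tir.merge_static_smem": True would also merge static shared
    # memory allocations, not just the .dyn ones.
    pass_cfg = {"tl.debug_merge_shared_memory_allocations": True}
    with tvm.transform.PassContext(config=pass_cfg):
        return tilelang.transform.MergeSharedMemoryAllocations()(mod)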