-
Notifications
You must be signed in to change notification settings - Fork 332
[Language] Support T.annotate_l2_hit_ratio via cudaStreamSetAttribute
#539
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b6d0684
8becbe3
bce403f
d62f966
46d8a0f
f6787cb
4973c7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| // Copyright (c) Tile-AI Corporation. | ||
| // Licensed under the MIT License. | ||
| /*! | ||
| * \file lower_l2_persistent_annotation.cc | ||
| * \brief Lower L2 persistent annotation | ||
| */ | ||
|
|
||
| #include <tvm/tir/analysis.h> | ||
| #include <tvm/tir/builtin.h> | ||
| #include <tvm/tir/stmt_functor.h> | ||
| #include <tvm/tir/transform.h> | ||
|
|
||
| #include "../op/builtin.h" | ||
| #include "../op/bulk_copy.h" | ||
| #include "../runtime/runtime.h" | ||
|
|
||
| namespace tvm { | ||
| namespace tl { | ||
|
|
||
| namespace attr { | ||
| // kL2RatioMap: block annotation key read by this pass (Map<Var, FloatImm>, buffer data var -> hit ratio). kL2PersistentMap: PrimFunc attribute key written by this pass (buffer name -> [hit ratio, size in bytes]). | ||
| constexpr const char *kL2RatioMap = "l2_hit_ratio_map"; | ||
| constexpr const char *kL2PersistentMap = "l2_persistent_map"; | ||
| } // namespace attr | ||
|
|
||
| using namespace tir; | ||
|
|
||
| /*! | ||
|  * \brief Lowers the per-block `l2_hit_ratio_map` annotation into a | ||
|  *        function-level `l2_persistent_map` attribute. | ||
|  * | ||
|  * While visiting blocks, the mutator records the requested hit ratio of | ||
|  * every annotated buffer and strips the annotation from the block. After the | ||
|  * traversal, `Substitute` attaches one attribute entry per buffer: | ||
|  *   buffer name -> [hit ratio, buffer size in bytes]. | ||
|  */ | ||
| class LowerL2Persistent : public StmtExprMutator { | ||
| public: | ||
|   /*! | ||
|    * \brief Rewrite `f` and attach the `l2_persistent_map` attribute. | ||
|    * \param f The function to rewrite; its body is updated in place. | ||
|    * \return The rewritten function. The attribute is only attached when at | ||
|    *         least one buffer carried a hit-ratio annotation. | ||
|    */ | ||
|   static PrimFunc Substitute(PrimFunc &f) { | ||
|     PrimFuncNode *fptr = f.CopyOnWrite(); | ||
|     LowerL2Persistent substituter; | ||
|     // Seed the lookup table with the buffers bound to function parameters. | ||
|     for (const auto &[_, buffer] : f->buffer_map) { | ||
|       substituter.buffer_data_to_buffer_.Set(buffer->data, buffer); | ||
|     } | ||
|     fptr->body = substituter.VisitStmt(f->body); | ||
|
|
||
|     Map<String, Array<PrimExpr>> init_l2_persistent_map; | ||
|     for (auto [buffer, hit_ratio] : substituter.hit_ratio_map_) { | ||
|       // Size in bytes = element size * product of all shape dimensions. | ||
|       // Start from an int64 immediate so the running product is 64-bit. | ||
|       PrimExpr size_in_bytes = IntImm(DataType::Int(64), buffer->dtype.bytes()); | ||
|       for (const auto &dim : buffer->shape) { | ||
|         size_in_bytes = size_in_bytes * dim; | ||
|       } | ||
|       // Argument 0: hit ratio | ||
|       // Argument 1: size in bytes | ||
|       Array<PrimExpr> l2_persistent_arguments; | ||
|       l2_persistent_arguments.push_back(hit_ratio); | ||
|       l2_persistent_arguments.push_back(size_in_bytes); | ||
|       init_l2_persistent_map.Set(buffer->name, l2_persistent_arguments); | ||
|     } | ||
|     if (init_l2_persistent_map.size() > 0) { | ||
|       f = WithAttr(std::move(f), attr::kL2PersistentMap, | ||
|                    init_l2_persistent_map); | ||
|     } | ||
|     return f; | ||
|   } | ||
|
|
||
|   /*! | ||
|    * \brief Collect hit ratios from a block and drop the annotation. | ||
|    * | ||
|    * Buffers allocated or matched by the block are registered before the | ||
|    * annotation is read, so the annotation may refer to either kind. | ||
|    */ | ||
|   Stmt VisitStmt_(const BlockNode *op) final { | ||
|     for (const auto &buffer : op->alloc_buffers) { | ||
|       buffer_data_to_buffer_.Set(buffer->data, buffer); | ||
|     } | ||
|     // Match buffers must be resolvable too; otherwise an annotation that | ||
|     // names one fails the lookup below. | ||
|     for (const auto &match_buffer : op->match_buffers) { | ||
|       buffer_data_to_buffer_.Set(match_buffer->buffer->data, | ||
|                                  match_buffer->buffer); | ||
|     } | ||
|
|
||
|     if (op->annotations.count(attr::kL2RatioMap)) { | ||
|       auto hit_ratio_map = op->annotations.at(attr::kL2RatioMap) | ||
|                                .as<Map<Var, FloatImm>>() | ||
|                                .value(); | ||
|       for (auto [buffer_var, hit_ratio] : hit_ratio_map) { | ||
|         ICHECK(buffer_data_to_buffer_.count(buffer_var)) | ||
|             << "l2_hit_ratio_map refers to unknown buffer var " | ||
|             << buffer_var->name_hint; | ||
|         Buffer buffer = buffer_data_to_buffer_.at(buffer_var); | ||
|         hit_ratio_map_.Set(buffer, hit_ratio); | ||
|       } | ||
|     } | ||
|
|
||
|     auto block = Downcast<Block>(StmtExprMutator::VisitStmt_(op)); | ||
|     // Only copy the block when there is an annotation to remove. | ||
|     if (block->annotations.count(attr::kL2RatioMap)) { | ||
|       block.CopyOnWrite()->annotations.erase(attr::kL2RatioMap); | ||
|     } | ||
|     return block; | ||
|   } | ||
|
|
||
| private: | ||
|   // Instances are only created through Substitute. | ||
|   LowerL2Persistent() = default; | ||
|   // Mapping from the data Var of a Buffer to the Buffer, for lookup. | ||
|   Map<Var, Buffer> buffer_data_to_buffer_; | ||
|   // Hit ratio requested for each annotated buffer. | ||
|   Map<Buffer, FloatImm> hit_ratio_map_; | ||
| }; | ||
|
|
||
| using namespace tir::transform; | ||
|
|
||
| /*! | ||
|  * \brief Create the PrimFunc pass that runs LowerL2Persistent::Substitute | ||
|  *        on each function. | ||
|  */ | ||
| tvm::transform::Pass LowerL2Persistent() { | ||
|   // The module and pass context are accepted to satisfy the callback | ||
|   // signature but are not used. | ||
|   auto lower = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc { | ||
|     return LowerL2Persistent::Substitute(func); | ||
|   }; | ||
|   return CreatePrimFuncPass(lower, /*opt_level=*/0, "tl.LowerL2Persistent", {}); | ||
| } | ||
|
|
||
| TVM_REGISTER_GLOBAL("tl.transform.LowerL2Persistent") | ||
| .set_body_typed(LowerL2Persistent); | ||
|
|
||
| } // namespace tl | ||
| } // namespace tvm | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -668,6 +668,7 @@ class TLVectorizer : public StmtMutator, | |
| } | ||
| } | ||
| } | ||
|
|
||
| // Allocate | ||
| Stmt VisitStmt_(const AllocateNode *op) final { | ||
| // Mutate the condition | ||
|
|
@@ -678,32 +679,7 @@ class TLVectorizer : public StmtMutator, | |
| return Scalarize(GetRef<Stmt>(op)); | ||
| } | ||
|
|
||
| // Mutate the extents | ||
| Array<PrimExpr> extents; | ||
| for (const auto &extent : op->extents) { | ||
| PrimExpr new_ext = this->VisitExpr(extent); | ||
| if (new_ext.dtype().is_scalable_or_fixed_length_vector()) { | ||
| LOG(WARNING) << "Cannot handle vector extent in alloc of " | ||
| << op->buffer_var->name_hint; | ||
| return Scalarize(GetRef<Stmt>(op)); | ||
| } | ||
| extents.push_back(new_ext); | ||
| } | ||
|
|
||
| // TODO(Lunderberg): Move this pass to be prior to | ||
| // StorageFlatten/FlattenBuffer. That will allow this pass to be | ||
| // implemented as adding a new buffer dimension, which is later | ||
| // flattened. | ||
|
|
||
| // Extend the least significant dimension by a factor of | ||
| // var_lanes_. Typically, this will be a 1-d index into a flat | ||
| // memory space. | ||
| extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_); | ||
| // Rewrite access to the buffer in the body. | ||
| Stmt body = | ||
| TLVecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); | ||
| body = this->VisitStmt(body); | ||
| return Allocate(op->buffer_var, op->dtype, extents, condition, body); | ||
| return StmtMutator::VisitStmt_(op); | ||
|
Comment on lines
681
to
+682
Contributor
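There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The logic for handling `AllocateNode` in `TLVectorizer` has been replaced with a direct call to `StmtMutator::VisitStmt_(op)`. This is a substantial change. Could you please elaborate on the rationale behind this modification?
Without understanding the reasoning, it's hard to assess the impact of this change. The removed code mutated the extents, scaled the last extent by `var_lanes_`, and rewrote accesses to the buffer through `TLVecAllocAccess`. |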
||
| } | ||
|
|
||
| // scalarize the statement | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -49,6 +49,29 @@ | |||||||||
| }} | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| L2_PERSISTENT_MAP_CREATE_HANDLE = """ | ||||||||||
| \tcudaStreamAttrValue stream_attribute; | ||||||||||
| \tsize_t init_persisting_l2_cache_size; | ||||||||||
| \tcudaDeviceGetLimit(&init_persisting_l2_cache_size, cudaLimitPersistingL2CacheSize); | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| L2_PERSISTENT_MAP_INIT_FUNC = """ | ||||||||||
| \tstream_attribute.accessPolicyWindow.hitRatio = {1}; | ||||||||||
| \tstream_attribute.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting; | ||||||||||
| \tstream_attribute.accessPolicyWindow.missProp = cudaAccessPropertyStreaming; | ||||||||||
| \tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, {3}); | ||||||||||
| \tstream_attribute.accessPolicyWindow.base_ptr = (void*)({0}); | ||||||||||
| \tstream_attribute.accessPolicyWindow.num_bytes = {3}; | ||||||||||
| \tcudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute); | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| L2_PERSISTENT_MAP_RESET_HANDLE = """ | ||||||||||
| \tstream_attribute.accessPolicyWindow.num_bytes = 0; | ||||||||||
| \tcudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute); | ||||||||||
| \tcudaCtxResetPersistingL2Cache(); | ||||||||||
| \tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, init_persisting_l2_cache_size); | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| TMA_DESC_INIT_FUNC = """ | ||||||||||
| \tCUtensorMap {0}; | ||||||||||
| \tCUtensorMapDataType {0}_type= (CUtensorMapDataType){1}; | ||||||||||
|
|
@@ -127,6 +150,7 @@ def __init__(self, | |||||||||
| self.block_info: Union[List[int], Dict] = [1, 1, 1] | ||||||||||
| self.grid_info: Union[List[int], Dict] = [1, 1, 1] | ||||||||||
| self.tma_descriptor_args: Optional[Dict] = None | ||||||||||
| self.l2_persistent_map: Optional[Dict[str, Dict]] = {} | ||||||||||
| self.parse_source_information() | ||||||||||
| self.srcpath: Optional[str] = None | ||||||||||
| self.libpath: Optional[str] = None | ||||||||||
|
|
@@ -196,7 +220,15 @@ def legalize_c(p): | |||||||||
| p = int(p) | ||||||||||
| return str(p).replace("//", "/") | ||||||||||
|
|
||||||||||
| has_l2_persistent_map = False | ||||||||||
| for function_name, _ in function_informations.items(): | ||||||||||
| if function_name in self.l2_persistent_map: | ||||||||||
| has_l2_persistent_map = True | ||||||||||
| break | ||||||||||
|
|
||||||||||
| kernel_launch_code = """""" | ||||||||||
| if has_l2_persistent_map: | ||||||||||
| kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE | ||||||||||
| desc_name_map: Dict[str, str] = {} | ||||||||||
| for function_name, function_info in function_informations.items(): | ||||||||||
| block_info = function_info["block_info"] | ||||||||||
|
|
@@ -221,16 +253,37 @@ def legalize_c(p): | |||||||||
| grid_str = "dim3({}, {}, {})".format( | ||||||||||
| legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2])) | ||||||||||
| smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf | ||||||||||
| init_l2_persistent_map = self.generate_l2_persistent_map(function_name) | ||||||||||
| kernel_launch_code += init_l2_persistent_map | ||||||||||
| kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format( | ||||||||||
| function_name, grid_str, block_str, smem_str, call_args) | ||||||||||
| kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name) | ||||||||||
| if has_l2_persistent_map: | ||||||||||
| kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE | ||||||||||
|
|
||||||||||
| kernel_launch_code = self.generate_tma_descriptor_args(desc_name_map) + kernel_launch_code | ||||||||||
| init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map) | ||||||||||
| kernel_launch_code = init_tma_descriptor_args + kernel_launch_code | ||||||||||
|
|
||||||||||
| # Wrap the kernel dispatch logic in an external C function | ||||||||||
| host_func = PREDEF_HOST_FUNC.format(def_args, kernel_launch_code) | ||||||||||
| return host_func | ||||||||||
|
|
||||||||||
| def generate_l2_persistent_map(self, function_name: str) -> str: | ||||||||||
| if function_name not in self.l2_persistent_map: | ||||||||||
| return "" | ||||||||||
| init_l2_persistent_map = "" | ||||||||||
| for buffer_name, (hit_ratio, | ||||||||||
| size_in_bytes) in self.l2_persistent_map[function_name].items(): | ||||||||||
| # get persisting_l2_cache_max_size | ||||||||||
| from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The import `from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size` is placed inside the `for` loop of `generate_l2_persistent_map`. While this might not cause major issues in a code generation context if the method isn't called extremely frequently in a tight loop, it's generally better for clarity, testability (mocking), and avoiding potential repeated import overhead to place imports at the module level. Could this import be moved to the top of the file?
Suggested change
|
||||||||||
| persisting_l2_cache_max_size = get_persisting_l2_cache_max_size() | ||||||||||
| num_bytes = min(size_in_bytes, persisting_l2_cache_max_size) | ||||||||||
|
|
||||||||||
| init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format( | ||||||||||
| buffer_name, float(hit_ratio), size_in_bytes, num_bytes) | ||||||||||
|
|
||||||||||
| return init_l2_persistent_map | ||||||||||
|
|
||||||||||
| def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str: | ||||||||||
| tma_descripter_init = "" | ||||||||||
| if self.tma_descriptor_args is None: | ||||||||||
|
|
@@ -263,10 +316,19 @@ def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str: | |||||||||
| box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank] | ||||||||||
| element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank] | ||||||||||
|
|
||||||||||
| global_dim = [str(i) for i in global_dim] | ||||||||||
| global_stride = [str(i) for i in global_stride] | ||||||||||
| box_dim = [str(i) for i in box_dim] | ||||||||||
| element_strides = [str(i) for i in element_strides] | ||||||||||
| def legalize_c2s(p): | ||||||||||
| # Convert a TIR expression to the string emitted into the generated C code. | ||||||||||
| # A `tvm.tir.IntImm` is first unwrapped to a plain Python int; | ||||||||||
| # any other value is passed to `str()` unchanged. | ||||||||||
| # Unlike `legalize_c`, no "//" -> "/" replacement is applied here. | ||||||||||
| if isinstance(p, tvm.tir.IntImm): | ||||||||||
| p = int(p) | ||||||||||
| return str(p) | ||||||||||
|
|
||||||||||
| global_dim = [legalize_c2s(i) for i in global_dim] | ||||||||||
| global_stride = [legalize_c2s(i) for i in global_stride] | ||||||||||
| box_dim = [legalize_c2s(i) for i in box_dim] | ||||||||||
| element_strides = [legalize_c2s(i) for i in element_strides] | ||||||||||
|
|
||||||||||
| # Extract remaining parameters | ||||||||||
| try: | ||||||||||
|
|
@@ -331,6 +393,9 @@ def parse_source_information(self): | |||||||||
| for _, func in self.host_mod.functions.items(): | ||||||||||
| if "tma_descriptor_args" in func.attrs: | ||||||||||
| self.tma_descriptor_args = func.attrs["tma_descriptor_args"] | ||||||||||
| if "l2_persistent_map" in func.attrs: | ||||||||||
| self.l2_persistent_map[function_name] = func.attrs["l2_persistent_map"] | ||||||||||
|
Comment on lines
+396
to
+397
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There appears to be a bug in how `self.l2_persistent_map` is populated. `function_name` is not bound by the enclosing `for _, func in self.host_mod.functions.items()` loop, so the key used here is not tied to the function that actually carries the `l2_persistent_map` attribute. Consider refactoring to populate `self.l2_persistent_map` from the device module instead:

    # Earlier in parse_source_information, after device_mod is available:
    self.l2_persistent_map = {}
    for g_var, device_func in self.device_mod.functions.items():
        func_name_hint = g_var.name_hint
        if "l2_persistent_map" in device_func.attrs:
            self.l2_persistent_map[func_name_hint] = device_func.attrs["l2_persistent_map"]

This change would ensure that L2 persistence settings are correctly associated with their respective device functions. |
||||||||||
|
|
||||||||||
| host_code = str(func) | ||||||||||
| for function_name in function_names: | ||||||||||
| index = host_code.index(f'T.call_packed("{function_name}"') | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The way `buffer_map_` (a `std::unordered_map`) and `buffer_data_to_buffer_` (a `tvm::Map`) are populated and used seems a bit complex.
- In `Substitute`, both maps are populated from `f->buffer_map`.
- In `VisitStmt_` (this section), `buffer_map_` is further populated from `op->alloc_buffers` and `op->match_buffers`.
- `buffer_data_to_buffer_` is populated again from `op->alloc_buffers`.
- The lookup for `hit_ratio_map` uses `buffer_data_to_buffer_.at(buffer_var)`.
Could this logic be simplified? For instance, if `buffer_data_to_buffer_` is the primary map for lookups, could its population be consolidated, or is `buffer_map_` serving a distinct purpose that's not immediately obvious from its usage here? Clarifying this or streamlining the map management could improve maintainability.