Skip to content

Commit db86ec4

Browse files
authored
[Language] Support T.annotate_l2_hit_ratio via cudaStreamSetAttribute (#539)
* Refactor OptimizeForTarget function by removing redundant buffer allocation step and cleaning up code
* Removed the PlanAndUpdateBufferAllocationLocation step from the OptimizeForTarget function to streamline the optimization process.
* Cleaned up unnecessary whitespace in the function for improved readability.
* Enhanced the overall clarity and maintainability of the code.
* Refactor AllocateNode handling in vectorize_loop.cc
* Simplified the VisitStmt_ method for AllocateNode by removing the complex extent mutation logic.
* Streamlined the allocation process to directly call the base class method, enhancing code clarity and maintainability.
* Improved overall readability by eliminating unnecessary comments and code related to extent handling.
* Remove `tl_kernel.c` file, eliminating the backward kernel implementation and associated error handling functions. This cleanup enhances code maintainability by removing unused components related to the backward kernel processing.
* Add buffer allocation planning step in OptimizeForTarget function
* Introduced the PlanAndUpdateBufferAllocationLocation step to the OptimizeForTarget function, enhancing the optimization process.
* This addition improves the overall efficiency of buffer allocation during the target optimization phase, ensuring better resource management.
* Update submodule TVM to latest commit db50d4e, ensuring alignment with upstream changes.
* Add L2 persistent annotation support and related functionality
* Introduced a new file `lower_l2_persistent_annotation.cc` to handle the lowering of L2 persistent annotations.
* Added functions to annotate L2 hit ratios for buffers, ensuring compatibility with global buffer requirements.
* Updated the `LowerAndLegalize` function to include the new L2 persistent map lowering step.
* Enhanced CUDA driver with a function to retrieve the maximum size of the persisting L2 cache.
* Modified the `TLCUDASourceWrapper` class to integrate L2 persistent map handling during kernel launches. These changes improve the framework's ability to manage L2 cache optimizations, enhancing performance for CUDA applications.
* lint fix
1 parent 77c9ab3 commit db86ec4

File tree

10 files changed

+221
-11
lines changed

10 files changed

+221
-11
lines changed

3rdparty/tvm

Submodule tvm updated from c2921fd to db50d4e

src/op/bulk_copy.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,9 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
173173
// The first stride element should be 1
174174
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
175175
// Make global stride in bytes
176-
desc.global_stride = desc.global_stride.Map(
177-
[&](PrimExpr e) { return e * global_tensor->dtype.bytes(); });
176+
desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
177+
return cast(DataType::Int(64), e) * global_tensor->dtype.bytes();
178+
});
178179

179180
// Smem Box
180181
desc.smem_box =
@@ -325,6 +326,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
325326
desc.data_type = to_CUtensorMapDataType(src->dtype);
326327
desc.global_addr = src->data;
327328
desc.global_shape = ReverseArray(src->shape);
329+
328330
if (!src->strides.empty()) {
329331
desc.global_stride = ReverseArray(src->strides);
330332
} else {
@@ -339,8 +341,9 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
339341
// The first stride element should be 1
340342
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
341343
// Make global stride in bytes
342-
desc.global_stride = desc.global_stride.Map(
343-
[&](PrimExpr e) { return e * src->dtype.bytes(); });
344+
desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
345+
return cast(DataType::Int(64), e) * src->dtype.bytes();
346+
});
344347
desc.elem_stride = {1, stride, stride, 1};
345348
desc.lower_corner = {-padding, -padding};
346349
desc.upper_corner = {-padding, -padding};
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright (c) Tile-AI Corporation.
2+
// Licensed under the MIT License.
3+
/*!
4+
* \file lower_l2_persistent_annotation.cc
5+
* \brief Lower L2 persistent annotation
6+
*/
7+
8+
#include <tvm/tir/analysis.h>
9+
#include <tvm/tir/builtin.h>
10+
#include <tvm/tir/stmt_functor.h>
11+
#include <tvm/tir/transform.h>
12+
13+
#include "../op/builtin.h"
14+
#include "../op/bulk_copy.h"
15+
#include "../runtime/runtime.h"
16+
17+
namespace tvm {
18+
namespace tl {
19+
20+
namespace attr {
21+
// Annotation keys: the block-level L2 hit-ratio map and the function-level L2 persistent map
22+
constexpr const char *kL2RatioMap = "l2_hit_ratio_map";
23+
constexpr const char *kL2PersistentMap = "l2_persistent_map";
24+
} // namespace attr
25+
26+
using namespace tir;
27+
28+
class LowerL2Persistent : public StmtExprMutator {
29+
public:
30+
static PrimFunc Substitute(PrimFunc &f) {
31+
PrimFuncNode *fptr = f.CopyOnWrite();
32+
LowerL2Persistent substituter;
33+
// Trace the buffer map for tvm_access_ptr
34+
substituter.buffer_map_.insert(f->buffer_map.begin(), f->buffer_map.end());
35+
for (const auto &[_, buffer] : f->buffer_map) {
36+
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
37+
}
38+
fptr->body = substituter.VisitStmt(f->body);
39+
Map<String, Array<PrimExpr>> init_l2_persistent_map;
40+
for (auto [buffer, hit_ratio] : substituter.hit_ratio_map_) {
41+
Array<PrimExpr> l2_persistent_arguments;
42+
// Argument 0: hit ratio
43+
// Argument 1: size in bytes
44+
l2_persistent_arguments.push_back(hit_ratio);
45+
PrimExpr size_in_bytes = IntImm(DataType::Int(64), buffer->dtype.bytes());
46+
for (auto dim : buffer->shape) {
47+
size_in_bytes = size_in_bytes * dim;
48+
}
49+
l2_persistent_arguments.push_back(size_in_bytes);
50+
init_l2_persistent_map.Set(buffer->name, l2_persistent_arguments);
51+
}
52+
if (init_l2_persistent_map.size() > 0) {
53+
f = WithAttr(std::move(f), attr::kL2PersistentMap,
54+
init_l2_persistent_map);
55+
}
56+
return f;
57+
}
58+
59+
Stmt VisitStmt_(const BlockNode *op) final {
60+
// Record the mapping from buffer data var to buffer for later lookup
61+
for (auto buffer : op->alloc_buffers) {
62+
buffer_map_.insert({buffer->data, buffer});
63+
}
64+
for (auto match_buffer : op->match_buffers) {
65+
buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
66+
}
67+
for (auto buffer : op->alloc_buffers) {
68+
buffer_data_to_buffer_.Set(buffer->data, buffer);
69+
}
70+
71+
if (op->annotations.count(attr::kL2RatioMap)) {
72+
auto hit_ratio_map = op->annotations.at(attr::kL2RatioMap)
73+
.as<Map<Var, FloatImm>>()
74+
.value();
75+
for (auto [buffer_var, hit_ratio] : hit_ratio_map) {
76+
Buffer buffer = buffer_data_to_buffer_.at(buffer_var);
77+
hit_ratio_map_.Set(buffer, hit_ratio);
78+
}
79+
}
80+
auto block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
81+
auto block_ptr = block.CopyOnWrite();
82+
block_ptr->annotations.erase(attr::kL2RatioMap);
83+
return block;
84+
}
85+
86+
private:
87+
// Mapping from data Var of a Buffer to Buffer, for lookup
88+
Map<Var, Buffer> buffer_data_to_buffer_;
89+
std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
90+
Map<Buffer, FloatImm> hit_ratio_map_;
91+
LowerL2Persistent() = default;
92+
};
93+
94+
using namespace tir::transform;
95+
96+
tvm::transform::Pass LowerL2Persistent() {
97+
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
98+
return LowerL2Persistent::Substitute(f);
99+
};
100+
return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {});
101+
}
102+
103+
TVM_REGISTER_GLOBAL("tl.transform.LowerL2Persistent")
104+
.set_body_typed(LowerL2Persistent);
105+
106+
} // namespace tl
107+
} // namespace tvm

tilelang/carver/arch/driver/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77
get_shared_memory_per_block, # noqa: F401
88
get_device_attribute, # noqa: F401
99
get_max_dynamic_shared_size_bytes, # noqa: F401
10+
get_persisting_l2_cache_max_size, # noqa: F401
1011
get_num_sms, # noqa: F401
1112
)

tilelang/carver/arch/driver/cuda_driver.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,14 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
167167
raise RuntimeError("Failed to get device properties.")
168168

169169

170+
def get_persisting_l2_cache_max_size(device_id: int = 0) -> int:
    """Return the `persistingL2CacheMaxSize` property of a CUDA device.

    Args:
        device_id: Index of the device to query (defaults to 0).

    Returns:
        The `persistingL2CacheMaxSize` field of the device properties.

    Raises:
        RuntimeError: If the device properties lookup yields a falsy result.
    """
    device_props = get_cuda_device_properties(device_id)
    if not device_props:
        raise RuntimeError("Failed to get device properties for persisting L2 cache max size.")
    return device_props.persistingL2CacheMaxSize
176+
177+
170178
def get_num_sms(device_id: int = 0) -> int:
171179
"""
172180
Get the number of streaming multiprocessors (SMs) on the CUDA device.

tilelang/engine/phase.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
6060
mod = tilelang.transform.LayoutInference()(mod)
6161
# Lower high-level tile operations to low-level operations
6262
mod = tilelang.transform.LowerTileOp()(mod)
63+
# Lower l2 persistent map
64+
mod = tilelang.transform.LowerL2Persistent()(mod)
6365
# Legalize vectorized loops to ensure they are valid
6466
mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
6567
# Add safety checks for memory accesses

tilelang/jit/adapter/libgen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def compile_lib(self, timeout: float = None):
9898

9999
src.write(self.lib_code)
100100
src.flush()
101+
101102
try:
102103
ret = subprocess.run(command, timeout=timeout)
103104
except Exception as e:

tilelang/jit/adapter/wrapper.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,29 @@
4949
}}
5050
"""
5151

52+
L2_PERSISTENT_MAP_CREATE_HANDLE = """
53+
\tcudaStreamAttrValue stream_attribute;
54+
\tsize_t init_persisting_l2_cache_size;
55+
\tcudaDeviceGetLimit(&init_persisting_l2_cache_size, cudaLimitPersistingL2CacheSize);
56+
"""
57+
58+
L2_PERSISTENT_MAP_INIT_FUNC = """
59+
\tstream_attribute.accessPolicyWindow.hitRatio = {1};
60+
\tstream_attribute.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
61+
\tstream_attribute.accessPolicyWindow.missProp = cudaAccessPropertyStreaming;
62+
\tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, {3});
63+
\tstream_attribute.accessPolicyWindow.base_ptr = (void*)({0});
64+
\tstream_attribute.accessPolicyWindow.num_bytes = {3};
65+
\tcudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
66+
"""
67+
68+
L2_PERSISTENT_MAP_RESET_HANDLE = """
69+
\tstream_attribute.accessPolicyWindow.num_bytes = 0;
70+
\tcudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
71+
\tcudaCtxResetPersistingL2Cache();
72+
\tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, init_persisting_l2_cache_size);
73+
"""
74+
5275
TMA_DESC_INIT_FUNC = """
5376
\tCUtensorMap {0};
5477
\tCUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
@@ -127,6 +150,7 @@ def __init__(self,
127150
self.block_info: Union[List[int], Dict] = [1, 1, 1]
128151
self.grid_info: Union[List[int], Dict] = [1, 1, 1]
129152
self.tma_descriptor_args: Optional[Dict] = None
153+
self.l2_persistent_map: Optional[Dict[str, Dict]] = {}
130154
self.parse_source_information()
131155
self.srcpath: Optional[str] = None
132156
self.libpath: Optional[str] = None
@@ -196,7 +220,15 @@ def legalize_c(p):
196220
p = int(p)
197221
return str(p).replace("//", "/")
198222

223+
has_l2_persistent_map = False
224+
for function_name, _ in function_informations.items():
225+
if function_name in self.l2_persistent_map:
226+
has_l2_persistent_map = True
227+
break
228+
199229
kernel_launch_code = """"""
230+
if has_l2_persistent_map:
231+
kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE
200232
desc_name_map: Dict[str, str] = {}
201233
for function_name, function_info in function_informations.items():
202234
block_info = function_info["block_info"]
@@ -221,16 +253,37 @@ def legalize_c(p):
221253
grid_str = "dim3({}, {}, {})".format(
222254
legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
223255
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
256+
init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
257+
kernel_launch_code += init_l2_persistent_map
224258
kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
225259
function_name, grid_str, block_str, smem_str, call_args)
226260
kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
261+
if has_l2_persistent_map:
262+
kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE
227263

228-
kernel_launch_code = self.generate_tma_descriptor_args(desc_name_map) + kernel_launch_code
264+
init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map)
265+
kernel_launch_code = init_tma_descriptor_args + kernel_launch_code
229266

230267
# Wrap the kernel dispatch logic in an external C function
231268
host_func = PREDEF_HOST_FUNC.format(def_args, kernel_launch_code)
232269
return host_func
233270

271+
def generate_l2_persistent_map(self, function_name: str) -> str:
    """Emit the host code that sets up the L2 access policy window for a kernel.

    Args:
        function_name: Name of the kernel whose `l2_persistent_map` entries
            should be turned into host code.

    Returns:
        One `L2_PERSISTENT_MAP_INIT_FUNC` snippet per annotated buffer,
        concatenated; an empty string when the kernel has no entries.
    """
    if function_name not in self.l2_persistent_map:
        return ""
    buffer_entries = self.l2_persistent_map[function_name]
    if len(buffer_entries) == 0:
        # Nothing to emit, so the device is not queried at all.
        return ""

    # Kept as a function-local import, as in the original implementation.
    from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size

    # The device limit does not depend on the buffer, so query it once
    # instead of once per loop iteration.
    persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()

    init_l2_persistent_map = ""
    for buffer_name, (hit_ratio, size_in_bytes) in buffer_entries.items():
        try:
            # Normalise to a plain Python int: a TIR integer constant
            # formatted directly may not render as a plain C literal in the
            # generated source.
            num_bytes = int(min(size_in_bytes, persisting_l2_cache_max_size))
        except Exception:
            # Deliberately broad: `size_in_bytes` may be a symbolic
            # expression that cannot be compared or converted at code
            # generation time. Fall back to the device maximum.
            num_bytes = persisting_l2_cache_max_size

        # `size_in_bytes` stays in position 2 so the template's `{3}`
        # placeholder keeps referring to `num_bytes`.
        init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(
            buffer_name, float(hit_ratio), size_in_bytes, num_bytes)

    return init_l2_persistent_map
286+
234287
def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
235288
tma_descripter_init = ""
236289
if self.tma_descriptor_args is None:
@@ -263,10 +316,19 @@ def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
263316
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
264317
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
265318

266-
global_dim = [str(i) for i in global_dim]
267-
global_stride = [str(i) for i in global_stride]
268-
box_dim = [str(i) for i in box_dim]
269-
element_strides = [str(i) for i in element_strides]
319+
def legalize_c2s(p):
    # Render a descriptor argument as a string.
    # A `tvm.tir.IntImm` is unwrapped to a Python int first, so that its
    # plain integer value is what gets stringified; anything else is
    # stringified as-is. Unlike `legalize_c`, no `//` -> `/` substitution
    # is applied to the result.
    return str(int(p)) if isinstance(p, tvm.tir.IntImm) else str(p)
327+
328+
global_dim = [legalize_c2s(i) for i in global_dim]
329+
global_stride = [legalize_c2s(i) for i in global_stride]
330+
box_dim = [legalize_c2s(i) for i in box_dim]
331+
element_strides = [legalize_c2s(i) for i in element_strides]
270332

271333
# Extract remaining parameters
272334
try:
@@ -331,6 +393,9 @@ def parse_source_information(self):
331393
for _, func in self.host_mod.functions.items():
332394
if "tma_descriptor_args" in func.attrs:
333395
self.tma_descriptor_args = func.attrs["tma_descriptor_args"]
396+
if "l2_persistent_map" in func.attrs:
397+
self.l2_persistent_map[function_name] = func.attrs["l2_persistent_map"]
398+
334399
host_code = str(func)
335400
for function_name in function_names:
336401
index = host_code.index(f'T.call_packed("{function_name}"')

tilelang/language/__init__.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,28 @@ def main(
147147
_padding_map = {}
148148
for buffer, padding_value in padding_map.items():
149149
# assert not global
150-
assert buffer.scope() != "global", "padding can only be applied to global buffers"
150+
assert buffer.scope() != "global", "padding can not be applied to global buffers"
151151
_padding_map[buffer.data] = padding_value
152152
return block_attr({"padding_map": _padding_map})
153153

154154

155+
def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict):
    """Attach L2 hit ratios to buffers via a block attribute.

    For background on the hit ratio, see:
    https://docs.nvidia.com/cuda/cuda-c-programming-guide/#l2-policy-for-persisting-accesses

    Args:
        l2_hit_ratio_map (dict): maps each buffer to its L2 hit ratio value.
            Every buffer must live in the "global" scope.

    Returns:
        The result of `block_attr` for the key "l2_hit_ratio_map", whose value
        maps each buffer's data var to its hit ratio.

    Example:
        # 0.5 is the hit ratio
        T.annotate_l2_hit_ratio({A: 0.5})
    """
    ratio_by_data_var = {}
    for buf, ratio in l2_hit_ratio_map.items():
        # Only buffers in the global scope are accepted.
        assert buf.scope() == "global", "persistent L2 can only be applied to global buffers"
        ratio_by_data_var[buf.data] = ratio
    annotation = {"l2_hit_ratio_map": ratio_by_data_var}
    return block_attr(annotation)
170+
171+
155172
def import_source(source: Optional[str] = None):
156173
# source is the source code to be imported
157174
return block_attr({"pragma_import_c": source}) if source is not None else None

tilelang/transform/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,9 @@ def MergeSharedMemoryAllocations():
344344
The result pass
345345
"""
346346
return _ffi_api.MergeSharedMemoryAllocations() # type: ignore
347+
348+
349+
def LowerL2Persistent():
    """Create the LowerL2Persistent pass.

    Returns
    -------
    The object returned by the `LowerL2Persistent` entry of `_ffi_api`.
    """
    lowering_pass = _ffi_api.LowerL2Persistent()  # type: ignore
    return lowering_pass

0 commit comments

Comments
 (0)