Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from c2921f to db50d4
11 changes: 7 additions & 4 deletions src/op/bulk_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
// The first stride element should be 1
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes
desc.global_stride = desc.global_stride.Map(
[&](PrimExpr e) { return e * global_tensor->dtype.bytes(); });
desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
return cast(DataType::Int(64), e) * global_tensor->dtype.bytes();
});

// Smem Box
desc.smem_box =
Expand Down Expand Up @@ -325,6 +326,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
desc.data_type = to_CUtensorMapDataType(src->dtype);
desc.global_addr = src->data;
desc.global_shape = ReverseArray(src->shape);

if (!src->strides.empty()) {
desc.global_stride = ReverseArray(src->strides);
} else {
Expand All @@ -339,8 +341,9 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
// The first stride element should be 1
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes
desc.global_stride = desc.global_stride.Map(
[&](PrimExpr e) { return e * src->dtype.bytes(); });
desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
return cast(DataType::Int(64), e) * src->dtype.bytes();
});
desc.elem_stride = {1, stride, stride, 1};
desc.lower_corner = {-padding, -padding};
desc.upper_corner = {-padding, -padding};
Expand Down
107 changes: 107 additions & 0 deletions src/transform/lower_l2_persistent_annotation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (c) Tile-AI Corporation.
// Licensed under the MIT License.
/*!
* \file lower_l2_persistent_annotation.cc
* \brief Lower L2 persistent annotation
*/

#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "../runtime/runtime.h"

namespace tvm {
namespace tl {

namespace attr {
// Block annotation key: looked up in BlockNode::annotations and decoded as a
// map from a buffer's data Var to its requested L2 hit ratio (FloatImm).
constexpr const char *kL2RatioMap = "l2_hit_ratio_map";
// PrimFunc attribute key: attached by LowerL2Persistent::Substitute and holds
// a map from buffer name to {hit ratio, size in bytes}.
constexpr const char *kL2PersistentMap = "l2_persistent_map";
} // namespace attr

using namespace tir;

class LowerL2Persistent : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc &f) {
  // Obtain a writable node first so the body can be replaced in place.
  PrimFuncNode *func_node = f.CopyOnWrite();
  LowerL2Persistent mutator;

  // Seed both lookup tables from the function's parameter buffers
  // (needed to resolve buffers referenced through tvm_access_ptr).
  for (const auto &[param_var, buffer] : f->buffer_map) {
    mutator.buffer_map_.insert({param_var, buffer});
  }
  for (const auto &[param_var, buffer] : f->buffer_map) {
    mutator.buffer_data_to_buffer_.Set(buffer->data, buffer);
  }

  // Visiting the body collects the hit ratios into hit_ratio_map_ and
  // strips the l2_hit_ratio_map annotation from the blocks.
  func_node->body = mutator.VisitStmt(f->body);

  // For every annotated buffer emit {hit ratio, size in bytes}, keyed by the
  // buffer name.
  Map<String, Array<PrimExpr>> persistent_map;
  for (const auto &[buffer, hit_ratio] : mutator.hit_ratio_map_) {
    // Size in bytes = element width * product of all shape extents,
    // accumulated starting from an int64 constant.
    PrimExpr num_bytes = IntImm(DataType::Int(64), buffer->dtype.bytes());
    for (const PrimExpr &extent : buffer->shape) {
      num_bytes = num_bytes * extent;
    }
    // Argument 0: hit ratio, Argument 1: size in bytes
    Array<PrimExpr> arguments;
    arguments.push_back(hit_ratio);
    arguments.push_back(num_bytes);
    persistent_map.Set(buffer->name, arguments);
  }

  // Only attach the attribute when at least one buffer was annotated.
  if (!persistent_map.empty()) {
    f = WithAttr(std::move(f), attr::kL2PersistentMap, persistent_map);
  }
  return f;
}

Stmt VisitStmt_(const BlockNode *op) final {
// Record the mapping from buffer data var to buffer for later lookup
for (auto buffer : op->alloc_buffers) {
buffer_map_.insert({buffer->data, buffer});
}
for (auto match_buffer : op->match_buffers) {
buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
}
for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
Comment on lines +60 to +69
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The way buffer_map_ (an std::unordered_map) and buffer_data_to_buffer_ (a tvm::Map) are populated and used seems a bit complex.

  1. In Substitute, both maps are populated from f->buffer_map.
  2. In VisitStmt_ (this section), buffer_map_ is further populated from op->alloc_buffers and op->match_buffers.
  3. Then, buffer_data_to_buffer_ is populated again from op->alloc_buffers.
  4. The actual lookup for hit_ratio_map uses buffer_data_to_buffer_.at(buffer_var).

Could this logic be simplified? For instance, if buffer_data_to_buffer_ is the primary map for lookups, could its population be consolidated, or is buffer_map_ serving a distinct purpose that's not immediately obvious from its usage here? Clarifying this or streamlining the map management could improve maintainability.


if (op->annotations.count(attr::kL2RatioMap)) {
auto hit_ratio_map = op->annotations.at(attr::kL2RatioMap)
.as<Map<Var, FloatImm>>()
.value();
for (auto [buffer_var, hit_ratio] : hit_ratio_map) {
Buffer buffer = buffer_data_to_buffer_.at(buffer_var);
hit_ratio_map_.Set(buffer, hit_ratio);
}
}
auto block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
auto block_ptr = block.CopyOnWrite();
block_ptr->annotations.erase(attr::kL2RatioMap);
return block;
}

private:
// Mapping from data Var of a Buffer to Buffer, for lookup.
// This is the table consulted when resolving l2_hit_ratio_map entries.
Map<Var, Buffer> buffer_data_to_buffer_;
// NOTE(review): filled in Substitute and VisitStmt_(const BlockNode *) but
// never read in the code visible here - confirm whether it is still needed.
std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
// Hit ratios gathered from block annotations, keyed by the resolved Buffer.
Map<Buffer, FloatImm> hit_ratio_map_;
// Private constructor: instances are created inside Substitute.
LowerL2Persistent() = default;
};

using namespace tir::transform;

tvm::transform::Pass LowerL2Persistent() {
  // Per-function callback: the module and pass context are not consulted,
  // all work is delegated to LowerL2Persistent::Substitute.
  auto lower_func = [](PrimFunc func, IRModule /*mod*/, PassContext /*ctx*/) {
    return LowerL2Persistent::Substitute(func);
  };
  return CreatePrimFuncPass(lower_func, 0, "tl.LowerL2Persistent", {});
}

// Register the pass factory in the global function registry under the name
// "tl.transform.LowerL2Persistent".
TVM_REGISTER_GLOBAL("tl.transform.LowerL2Persistent")
.set_body_typed(LowerL2Persistent);

} // namespace tl
} // namespace tvm
28 changes: 2 additions & 26 deletions src/transform/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,7 @@ class TLVectorizer : public StmtMutator,
}
}
}

// Allocate
Stmt VisitStmt_(const AllocateNode *op) final {
// Mutate the condition
Expand All @@ -678,32 +679,7 @@ class TLVectorizer : public StmtMutator,
return Scalarize(GetRef<Stmt>(op));
}

// Mutate the extents
Array<PrimExpr> extents;
for (const auto &extent : op->extents) {
PrimExpr new_ext = this->VisitExpr(extent);
if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
}
extents.push_back(new_ext);
}

// TODO(Lunderberg): Move this pass to be prior to
// StorageFlatten/FlattenBuffer. That will allow this pass to be
// implemented as adding a new buffer dimension, which is later
// flattened.

// Extend the least significant dimension by a factor of
// var_lanes_. Typically, this will be a 1-d index into a flat
// memory space.
extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_);
// Rewrite access to the buffer in the body.
Stmt body =
TLVecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
body = this->VisitStmt(body);
return Allocate(op->buffer_var, op->dtype, extents, condition, body);
return StmtMutator::VisitStmt_(op);
Comment on lines 681 to +682
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for handling AllocateNode within TLVectorizer::VisitStmt_ has been significantly changed. Previously, there was specific code to mutate extents, extend the least significant dimension by var_lanes_, and rewrite buffer access using TLVecAllocAccess. This has been replaced by return StmtMutator::VisitStmt_(op);.

This is a substantial change. Could you please elaborate on the rationale behind this modification?

  • How is the vectorization of allocations handled now?
  • Is the functionality previously provided by the removed code now achieved elsewhere, or is it no longer considered necessary?
  • What are the implications for correctness and performance of vectorized loops involving allocations?

Without understanding the reasoning, it's hard to assess the impact of this change. The TLVecAllocAccess class is still defined in this file but doesn't seem to be used by the AllocateNode visitor anymore.

}

// scalarize the statement
Expand Down
1 change: 1 addition & 0 deletions tilelang/carver/arch/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
get_shared_memory_per_block, # noqa: F401
get_device_attribute, # noqa: F401
get_max_dynamic_shared_size_bytes, # noqa: F401
get_persisting_l2_cache_max_size, # noqa: F401
get_num_sms, # noqa: F401
)
8 changes: 8 additions & 0 deletions tilelang/carver/arch/driver/cuda_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,14 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
raise RuntimeError("Failed to get device properties.")


def get_persisting_l2_cache_max_size(device_id: int = 0) -> int:
    """
    Get the maximum persisting L2 cache size reported for a CUDA device.

    Args:
        device_id (int, optional): The CUDA device ID. Defaults to 0.

    Returns:
        int: The ``persistingL2CacheMaxSize`` field of the device properties.

    Raises:
        RuntimeError: If ``get_cuda_device_properties`` returns a falsy value.
    """
    device_props = get_cuda_device_properties(device_id)
    if not device_props:
        raise RuntimeError("Failed to get device properties for persisting L2 cache max size.")
    return device_props.persistingL2CacheMaxSize


def get_num_sms(device_id: int = 0) -> int:
"""
Get the number of streaming multiprocessors (SMs) on the CUDA device.
Expand Down
3 changes: 2 additions & 1 deletion tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LayoutInference()(mod)
# Lower high-level tile operations to low-level operations
mod = tilelang.transform.LowerTileOp()(mod)
# Lower l2 persistent map
mod = tilelang.transform.LowerL2Persistent()(mod)
# Legalize vectorized loops to ensure they are valid
mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
# Add safety checks for memory accesses
Expand Down Expand Up @@ -110,7 +112,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tir.transform.Simplify()(mod)

mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)

mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
Expand Down
4 changes: 3 additions & 1 deletion tilelang/jit/adapter/libgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,15 @@ def compile_lib(self, timeout: float = None):

src.write(self.lib_code)
src.flush()

try:
ret = subprocess.run(command, timeout=timeout)
except Exception as e:
raise RuntimeError(f"Compile kernel failed because of {e}") from e

if ret.returncode != 0:
raise RuntimeError(f"Compilation Failed! {command}")
raise RuntimeError(f"Compilation Failed! {command}"
f"\n {self.lib_code}")

self.srcpath = src.name
self.libpath = libpath
Expand Down
75 changes: 70 additions & 5 deletions tilelang/jit/adapter/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,29 @@
}}
"""

# Host-side C snippet emitted once before the kernel launches when any kernel
# has an L2 persistent map: declares the stream attribute value and saves the
# current persisting-L2 limit so the reset snippet can restore it.
L2_PERSISTENT_MAP_CREATE_HANDLE = """
\tcudaStreamAttrValue stream_attribute;
\tsize_t init_persisting_l2_cache_size;
\tcudaDeviceGetLimit(&init_persisting_l2_cache_size, cudaLimitPersistingL2CacheSize);
"""

# Per-buffer snippet, formatted with
# (buffer_name, hit_ratio, size_in_bytes, num_bytes):
#   {0} -> base pointer of the access policy window (the buffer name)
#   {1} -> hit ratio
#   {3} -> window size in bytes, also used as the persisting-L2 limit
# NOTE(review): {2} (size_in_bytes) is passed by the caller but not referenced
# by this template.
L2_PERSISTENT_MAP_INIT_FUNC = """
\tstream_attribute.accessPolicyWindow.hitRatio = {1};
\tstream_attribute.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
\tstream_attribute.accessPolicyWindow.missProp = cudaAccessPropertyStreaming;
\tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, {3});
\tstream_attribute.accessPolicyWindow.base_ptr = (void*)({0});
\tstream_attribute.accessPolicyWindow.num_bytes = {3};
\tcudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
"""

# Snippet emitted once after the kernel launches: sets the window size to
# zero, resets the persisting L2 cache and restores the limit saved by
# L2_PERSISTENT_MAP_CREATE_HANDLE.
L2_PERSISTENT_MAP_RESET_HANDLE = """
\tstream_attribute.accessPolicyWindow.num_bytes = 0;
\tcudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
\tcudaCtxResetPersistingL2Cache();
\tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, init_persisting_l2_cache_size);
"""

TMA_DESC_INIT_FUNC = """
\tCUtensorMap {0};
\tCUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
Expand Down Expand Up @@ -127,6 +150,7 @@ def __init__(self,
self.block_info: Union[List[int], Dict] = [1, 1, 1]
self.grid_info: Union[List[int], Dict] = [1, 1, 1]
self.tma_descriptor_args: Optional[Dict] = None
self.l2_persistent_map: Optional[Dict[str, Dict]] = {}
self.parse_source_information()
self.srcpath: Optional[str] = None
self.libpath: Optional[str] = None
Expand Down Expand Up @@ -196,7 +220,15 @@ def legalize_c(p):
p = int(p)
return str(p).replace("//", "/")

has_l2_persistent_map = False
for function_name, _ in function_informations.items():
if function_name in self.l2_persistent_map:
has_l2_persistent_map = True
break

kernel_launch_code = """"""
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE
desc_name_map: Dict[str, str] = {}
for function_name, function_info in function_informations.items():
block_info = function_info["block_info"]
Expand All @@ -221,16 +253,37 @@ def legalize_c(p):
grid_str = "dim3({}, {}, {})".format(
legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
kernel_launch_code += init_l2_persistent_map
kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
function_name, grid_str, block_str, smem_str, call_args)
kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE

kernel_launch_code = self.generate_tma_descriptor_args(desc_name_map) + kernel_launch_code
init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map)
kernel_launch_code = init_tma_descriptor_args + kernel_launch_code

# Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC.format(def_args, kernel_launch_code)
return host_func

def generate_l2_persistent_map(self, function_name: str) -> str:
if function_name not in self.l2_persistent_map:
return ""
init_l2_persistent_map = ""
for buffer_name, (hit_ratio,
size_in_bytes) in self.l2_persistent_map[function_name].items():
# get persisting_l2_cache_max_size
from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The import from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size is performed inside the generate_l2_persistent_map method. According to PEP 8, imports should usually be at the top of the file.

While this might not cause major issues in a code generation context if the method isn't called extremely frequently in a tight loop, it's generally better for clarity, testability (mocking), and avoiding potential repeated import overhead to place imports at the module level.

Could this import be moved to the top of the tilelang/jit/adapter/wrapper.py file?

Suggested change
from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size
# get persisting_l2_cache_max_size
# Ensure 'get_persisting_l2_cache_max_size' is imported at the top of the file.
persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()

persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()
num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)

init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(
buffer_name, float(hit_ratio), size_in_bytes, num_bytes)

return init_l2_persistent_map

def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
tma_descripter_init = ""
if self.tma_descriptor_args is None:
Expand Down Expand Up @@ -263,10 +316,19 @@ def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]

global_dim = [str(i) for i in global_dim]
global_stride = [str(i) for i in global_stride]
box_dim = [str(i) for i in box_dim]
element_strides = [str(i) for i in element_strides]
def legalize_c2s(p):
    # Render a TIR value as text for the generated C source: an IntImm is
    # converted to a Python int first, anything else is stringified as-is.
    value = int(p) if isinstance(p, tvm.tir.IntImm) else p
    return str(value)

global_dim = [legalize_c2s(i) for i in global_dim]
global_stride = [legalize_c2s(i) for i in global_stride]
box_dim = [legalize_c2s(i) for i in box_dim]
element_strides = [legalize_c2s(i) for i in element_strides]

# Extract remaining parameters
try:
Expand Down Expand Up @@ -331,6 +393,9 @@ def parse_source_information(self):
for _, func in self.host_mod.functions.items():
if "tma_descriptor_args" in func.attrs:
self.tma_descriptor_args = func.attrs["tma_descriptor_args"]
if "l2_persistent_map" in func.attrs:
self.l2_persistent_map[function_name] = func.attrs["l2_persistent_map"]
Comment on lines +396 to +397
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There appears to be a bug in how self.l2_persistent_map is populated. The function_name used as a key on line 397 (self.l2_persistent_map[function_name]) is from the inner loop for function_name in function_names: (defined on line 400). However, func.attrs["l2_persistent_map"] comes from the outer loop's func (from self.host_mod.functions.items()).

If self.host_mod.functions has one item (as asserted), this means func.attrs["l2_persistent_map"] (if it exists on the host function's attributes) would be assigned to self.l2_persistent_map[dev_func_name] for every device function name in function_names. This is likely incorrect.

The l2_persistent_map attribute is added to the device PrimFunc by the LowerL2Persistent C++ pass. Therefore, this information should be retrieved from self.device_mod.functions[g_var].attrs for each device function.

Consider refactoring to populate self.l2_persistent_map by iterating through self.device_mod.functions and accessing attributes from each device function, similar to how block_info_map, grid_info_map, etc., are populated earlier in this method. For example:

# Earlier in parse_source_information, after device_mod is available:
self.l2_persistent_map = {}
for g_var, device_func in self.device_mod.functions.items():
    func_name_hint = g_var.name_hint
    if "l2_persistent_map" in device_func.attrs:
        self.l2_persistent_map[func_name_hint] = device_func.attrs["l2_persistent_map"]

This change would ensure that L2 persistence settings are correctly associated with their respective device functions.


host_code = str(func)
for function_name in function_names:
index = host_code.index(f'T.call_packed("{function_name}"')
Expand Down
19 changes: 18 additions & 1 deletion tilelang/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,28 @@ def main(
_padding_map = {}
for buffer, padding_value in padding_map.items():
# assert not global
assert buffer.scope() != "global", "padding can only be applied to global buffers"
assert buffer.scope() != "global", "padding can not be applied to global buffers"
_padding_map[buffer.data] = padding_value
return block_attr({"padding_map": _padding_map})


def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict):
    """Build the ``l2_hit_ratio_map`` block attribute from a buffer -> hit ratio mapping.

    For background on L2 persisting accesses, see:
    https://docs.nvidia.com/cuda/cuda-c-programming-guide/#l2-policy-for-persisting-accesses

    Args:
        l2_hit_ratio_map (dict): maps each buffer to its L2 hit ratio value.
            Every buffer must have ``"global"`` scope.

    Returns:
        The result of ``block_attr`` for a dict keyed by each buffer's ``data``.

    Example:
        # 0.5 is the hit ratio
        T.annotate_l2_hit_ratio({A: 0.5})
    """
    ratio_by_data_var = {}
    for buf, ratio in l2_hit_ratio_map.items():
        assert buf.scope() == "global", "persistent L2 can only be applied to global buffers"
        ratio_by_data_var[buf.data] = ratio
    return block_attr({"l2_hit_ratio_map": ratio_by_data_var})


def import_source(source: Optional[str] = None):
# source is the source code to be imported
return block_attr({"pragma_import_c": source}) if source is not None else None
Loading
Loading