Commit 0e63caa
[FFI] Use tvm ffi as the default execution backend (tile-ai#1259)
* [Refactor] Update FFI type handling and simplify argument management
  * Refactored FFI type definitions in the runtime and code-generation files to use `TVMFFIAny` instead of `TVMValue`, improving type clarity.
  * Updated function registration in `runtime.cc` to use the canonical names for consistency.
  * Simplified argument handling in the `simplify` transformation so that unused buffer parameters are removed only when simplification is enabled.
  * Standardized the autotuner and profiler execution backend to `tvm_ffi`, clarifying backend selection.
  * Removed the obsolete `adapt_torch2tvm` function from the tensor utilities to reduce complexity.
* [Update] Sync TVM submodule and enhance kernel-source handling
  * Updated the TVM submodule to commit cdc2aced for compatibility with recent changes.
  * Added printing of the kernel source in `example_blocksparse_gemm.py` for easier debugging.
  * Commented out the main execution call in test files to prevent unintended execution during testing.
  * Introduced `tilelang.disable_cache()` in various test files to avoid cache-related issues.
  * Refactored kernel-source retrieval for clarity and consistency across execution backends.
* [Refactor] Clean up imports and improve code formatting
  * Removed the unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py`.
  * Reformatted lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for readability and consistency.
  * Updated comments and spacing in `tvm_ffi.py` without altering functionality.
* Update execution backend options and improve resolution logic
  * Changed the default execution backend from "cython" to "auto" in multiple locations, allowing automatic selection based on the target.
  * Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions.
  * Enhanced the backend-resolution logic in `KernelCache` and `AutoTuner` to pick an appropriate backend for the target.
  * Updated documentation to reflect the new execution backend options and defaults.
* Enhance argument handling in the CUDA and HIP runtime modules
  * Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with the device runtime.
  * Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation of DLTensor parameters, using expression-level guards to avoid dereferencing null pointers.
  * Strengthened error checking for buffer shape, strides, and data fields, handling optional inputs robustly and keeping the checks consistent.
* Refactor argument binding and validation in `arg_binder.cc`
  * Improved null handling and validation checks in `BindDLTensor`, ensuring pointers are dereferenced safely.
  * Updated `MakePackedAPI` to keep argument handling clear and consistent.
  * Made minor adjustments in test files to streamline kernel execution and improve readability.
* Add CUDA stream access-policy-window helpers and integrate them with L2 persistent cache management
  * Introduced functions to set and reset the CUDA stream access policy window, allowing finer control over L2 cache usage.
  * Updated the runtime files with new FFI packed functions for managing stream attributes.
  * Modified `lower_hopper_intrin` to add prologue and epilogue statements for L2 cache setup and teardown.
  * Enhanced tests to verify that the new FFI calls appear in the generated kernel source.
* Update CMakeLists and `lower.py` for code generation and subproject status
  * Added `codegen_c_host.cc` to the source list in CMakeLists.txt for improved code-generation support.
  * Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target-host code generation.
  * Marked the TVM subproject as dirty to indicate local modifications.
* Assorted smaller changes: lint cleanups, a stride fix, checks with symbolic shapes, null-pointer support, a restored check, and clarified comments in quickstart.py.
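The "auto" backend selection described above can be sketched as a small dispatcher. This is a minimal illustration, not tilelang's actual implementation: the function name, the exact backend list, and the target checks are assumptions made for the example.

```python
def resolve_execution_backend(backend: str, target: str) -> str:
    """Illustrative sketch of 'auto' execution-backend resolution.

    'auto' picks a concrete backend based on the compilation target,
    mirroring the idea in this commit. Names and rules here are
    hypothetical, not tilelang's real resolution logic.
    """
    supported = {"tvm_ffi", "cython", "ctypes", "torch", "nvrtc"}
    if backend != "auto":
        if backend not in supported:
            raise ValueError(f"unsupported execution backend: {backend}")
        return backend
    # Prefer the TVM FFI backend (the new default) for device targets,
    # falling back to a host-side backend otherwise.
    if target.startswith(("cuda", "hip")):
        return "tvm_ffi"
    return "cython"
```

The point of routing everything through one resolver is that `KernelCache` and `AutoTuner` can share the same target-to-backend decision instead of each hard-coding a default.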
1 parent e3f167f commit 0e63caa

43 files changed: +2721, -928 lines

3rdparty/tvm

Submodule tvm updated from 093b2cd to f4105f8

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -138,6 +138,7 @@ file(GLOB TILE_LANG_SRCS
   src/transform/*.cc
   src/op/*.cc
   src/target/utils.cc
+  src/target/codegen_c_host.cc
   src/target/codegen_cpp.cc
   src/target/rt_mod_cpp.cc
   # intrin_rule doesn't have system dependency

examples/blocksparse_gemm/example_blocksparse_gemm.py

Lines changed: 0 additions & 1 deletion
@@ -166,7 +166,6 @@ def main():
         enable_rasteration=DEFAULT_ENABLE_RASTERIZATION)
     block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
     print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
-
     # Create block mask with desired sparsity
     mask_shape = (M // block_M, N // block_N, K // block_K)
     block_mask = torch.rand(mask_shape).cuda() > sparsity

examples/gdn/example_chunk_o_bwd.py

Lines changed: 0 additions & 1 deletion
@@ -468,7 +468,6 @@ def run_test(
     kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
                                         gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw,
                                         block_DK, block_DV, threads, num_stages)
-    print(kernel.get_kernel_source())
     dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W)

     if use_g:

examples/gdn/test_example_gdn_compilation.py

Lines changed: 1 addition & 0 deletions
@@ -117,6 +117,7 @@ def test_example_chunk_o_bwd_compilation():
     kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
                                         gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
                                         block_DK, block_DV, threads, num_stages)
+
     dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv,
                                                                 W)  # noqa: F841
     if use_g:

examples/quickstart.py

Lines changed: 2 additions & 3 deletions
@@ -55,10 +55,9 @@ def matmul_relu_kernel(
     block_N = 128
     block_K = 32

-    # 1. Define the kernel (matmul) and compile/lower it into an executable module
+    # Define the kernel (matmul) and compile/lower it into an executable module
    matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
-
-    # 3. Test the kernel in Python with PyTorch data
+    # Test the kernel in Python with PyTorch data
    import torch

    # Create random input tensors on the GPU

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ tilelang = "tilelang"
 # TVM
 "tilelang/3rdparty/tvm/src" = "3rdparty/tvm/src"
 "tilelang/3rdparty/tvm/python" = "3rdparty/tvm/python"
+"tilelang/3rdparty/tvm/include" = "3rdparty/tvm/include"
 "tilelang/3rdparty/tvm/version.py" = "3rdparty/tvm/version.py"
 # CUTLASS
 "tilelang/3rdparty/cutlass/include" = "3rdparty/cutlass/include"

src/runtime/runtime.cc

Lines changed: 158 additions & 14 deletions
@@ -13,6 +13,12 @@
 namespace tvm {
 namespace tl {

+#if 1
+// Thread-local storage for restoring the L2 persisting cache limit
+static thread_local size_t __tl_prev_persisting_l2_cache_size = 0;
+static thread_local bool __tl_prev_persisting_l2_cache_saved = false;
+#endif
+
 #if (CUDA_MAJOR_VERSION >= 12)
 template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
   std::stringstream ss;
@@ -91,19 +97,21 @@ struct TensorMapArgs {
 // set device api
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args,
-                                                                Any *ret) {
-    TensorMapArgs T = TensorMapArgs::Extract(args);
-    CUresult result = cuTensorMapEncodeTiled(
-        T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
-        T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, T.swizzle,
-        T.l2Promotion, T.oobFill);
-    if (result != CUDA_SUCCESS) {
-      LOG_FATAL << "Failed to initialize the TMA descriptor " << result << '\n'
-                << T.ToDebugString();
-    }
-    *ret = static_cast<int>(result);
-  });
+  // Register using the canonical names defined in runtime.h
+  refl::GlobalDef().def_packed(
+      tl::tvm_tensormap_create_tiled, [](PackedArgs args, Any *ret) {
+        TensorMapArgs T = TensorMapArgs::Extract(args);
+        CUresult result = cuTensorMapEncodeTiled(
+            T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
+            T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
+            T.swizzle, T.l2Promotion, T.oobFill);
+        if (result != CUDA_SUCCESS) {
+          LOG_FATAL << "Failed to initialize the TMA descriptor " << result
+                    << '\n'
+                    << T.ToDebugString();
+        }
+        *ret = static_cast<int>(result);
+      });
 }

 struct TensorMapIm2ColArgs {
@@ -183,7 +191,7 @@ struct TensorMapIm2ColArgs {
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def_packed(
-      "tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) {
+      tl::tvm_tensormap_create_im2col, [](PackedArgs args, Any *ret) {
         TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
         CUresult result = cuTensorMapEncodeIm2col(
             T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
@@ -201,5 +209,141 @@ TVM_FFI_STATIC_INIT_BLOCK() {

 #endif // (CUDA_MAJOR_VERSION >= 12)

+//
+// CUDA L2 Persisting Cache Access Policy Window helpers.
+// Exposed as TVM FFI packed functions similar to TMA initialization.
+//
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  // Set stream access policy window and adjust persisting L2 cache size
+  // Args:
+  //   [0]: void* base_ptr (required)
+  //   [1]: int64 num_bytes (required)
+  //   [2]: float hit_ratio (optional, default 0.8)
+  //   [3]: void* stream (optional, default 0 => default stream)
+  //   [4]: int64 l2_limit_bytes (optional, default = num_bytes)
+  refl::GlobalDef().def_packed(
+      tl::tvm_cuda_stream_set_access_policy_window,
+      [](PackedArgs args, Any *ret) {
+        ICHECK(args.size() >= 2) << "Expected at least base_ptr and num_bytes";
+
+        void *base_ptr = args[0].cast<void *>();
+        size_t num_bytes = static_cast<size_t>(args[1].cast<int64_t>());
+        float hit_ratio = 0.8f;
+        if (args.size() >= 3) {
+          // Accept double/float
+          hit_ratio = static_cast<float>(args[2].cast<double>());
+        }
+        CUstream stream = nullptr;
+        if (args.size() >= 4) {
+          stream = reinterpret_cast<CUstream>(args[3].cast<void *>());
+        }
+        size_t l2_limit_bytes = num_bytes;
+        if (args.size() >= 5) {
+          l2_limit_bytes = static_cast<size_t>(args[4].cast<int64_t>());
+        }
+
+        // Clamp requested limit to device capability
+        CUdevice device;
+        CUresult result = cuCtxGetDevice(&device);
+        if (result != CUDA_SUCCESS) {
+          LOG_FATAL << "Failed to get current CUDA device: " << result;
+        }
+        int max_persisting = 0;
+        result = cuDeviceGetAttribute(
+            &max_persisting, CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE,
+            device);
+        if (result != CUDA_SUCCESS) {
+          LOG_FATAL << "Failed to query MAX_PERSISTING_L2_CACHE_SIZE: "
+                    << result;
+        }
+        if (max_persisting > 0 &&
+            l2_limit_bytes > static_cast<size_t>(max_persisting)) {
+          l2_limit_bytes = static_cast<size_t>(max_persisting);
+        }
+
+        // Save current limit to restore later
+        size_t init_persisting_l2_cache_size = 0;
+        result = cuCtxGetLimit(&init_persisting_l2_cache_size,
+                               CU_LIMIT_PERSISTING_L2_CACHE_SIZE);
+        if (result != CUDA_SUCCESS) {
+          LOG_FATAL << "Failed to get current persisting L2 cache size limit: "
+                    << result;
+        }
+        __tl_prev_persisting_l2_cache_size = init_persisting_l2_cache_size;
+        __tl_prev_persisting_l2_cache_saved = true;
+
+        // Set new limit
+        result =
+            cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE, l2_limit_bytes);
+        if (result != CUDA_SUCCESS) {
+          LOG_FATAL << "Failed to set persisting L2 cache size limit: "
+                    << result;
+        }
+
+        // Apply access policy window to stream
+        CUstreamAttrValue stream_attribute;
+        memset(&stream_attribute, 0, sizeof(stream_attribute));
+        stream_attribute.accessPolicyWindow.base_ptr = base_ptr;
+        stream_attribute.accessPolicyWindow.num_bytes = l2_limit_bytes;
+        stream_attribute.accessPolicyWindow.hitRatio = hit_ratio;
+        stream_attribute.accessPolicyWindow.hitProp =
+            CU_ACCESS_PROPERTY_PERSISTING;
+        stream_attribute.accessPolicyWindow.missProp =
+            CU_ACCESS_PROPERTY_STREAMING;
+
+        result = cuStreamSetAttribute(stream,
+                                      CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW,
+                                      &stream_attribute);
+        if (result != CUDA_SUCCESS) {
+          LOG_FATAL << "Failed to set stream access policy window: " << result;
+        }
+
+        *ret = static_cast<int>(result);
+      });
+
+  // Reset stream access policy window and restore the previous L2 cache size
+  // Args:
+  //   [0]: void* stream (optional, default 0)
+  refl::GlobalDef().def_packed(
+      tl::tvm_cuda_stream_reset_access_policy_window,
+      [](PackedArgs args, Any *ret) {
+        CUstream stream = nullptr;
+        if (args.size() >= 1) {
+          stream = reinterpret_cast<CUstream>(args[0].cast<void *>());
+        }
+
+        CUstreamAttrValue stream_attribute;
+        memset(&stream_attribute, 0, sizeof(stream_attribute));
+        // num_bytes = 0 disables the access policy window on the stream
+        stream_attribute.accessPolicyWindow.num_bytes = 0;
+
+        CUresult result = cuStreamSetAttribute(
+            stream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW,
+            &stream_attribute);
+        if (result != CUDA_SUCCESS) {
+          LOG_FATAL << "Failed to reset stream access policy window: "
+                    << result;
+        }
+
+        result = cuCtxResetPersistingL2Cache();
+        if (result != CUDA_SUCCESS) {
+          LOG_FATAL << "Failed to reset persisting L2 cache lines: " << result;
+        }
+
+        if (__tl_prev_persisting_l2_cache_saved) {
+          result = cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE,
+                                 __tl_prev_persisting_l2_cache_size);
+          if (result != CUDA_SUCCESS) {
+            LOG_FATAL << "Failed to restore persisting L2 cache size limit: "
+                      << result;
+          }
+          __tl_prev_persisting_l2_cache_saved = false;
+        }
+
+        *ret = static_cast<int>(result);
+      });
+}
+
 } // namespace tl
 } // namespace tvm
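The packed function above resolves its optional arguments positionally (hit ratio, stream, L2 limit) and clamps the requested limit to the device's capability. A pure-Python mirror of that argument handling, written only to illustrate the defaults and the clamp (the function names here are invented for the sketch, and the real device query uses `CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE`):

```python
def parse_policy_window_args(args):
    """Mirror of the packed-argument handling in the set-access-policy-window
    helper: [base_ptr, num_bytes, hit_ratio?, stream?, l2_limit_bytes?].
    Returns the resolved (base_ptr, num_bytes, hit_ratio, stream, l2_limit).
    """
    assert len(args) >= 2, "Expected at least base_ptr and num_bytes"
    base_ptr, num_bytes = args[0], int(args[1])
    hit_ratio = float(args[2]) if len(args) >= 3 else 0.8  # default 0.8
    stream = args[3] if len(args) >= 4 else 0              # default stream
    l2_limit = int(args[4]) if len(args) >= 5 else num_bytes
    return base_ptr, num_bytes, hit_ratio, stream, l2_limit


def clamp_l2_limit(l2_limit, max_persisting):
    """Clamp the requested persisting-L2 limit to the device maximum,
    as the C++ helper does after querying the device attribute.
    A non-positive max means the attribute is unavailable, so no clamp."""
    if max_persisting > 0 and l2_limit > max_persisting:
        return max_persisting
    return l2_limit
```

Clamping before calling `cuCtxSetLimit` matters because requesting more persisting L2 than the device supports would otherwise fail or silently misbehave, and saving the previous limit lets the reset helper restore it afterwards.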

src/runtime/runtime.h

Lines changed: 7 additions & 1 deletion
@@ -16,7 +16,13 @@ constexpr const char *tvm_tensormap_create_tiled =
 constexpr const char *tvm_tensormap_create_im2col =
     "__tvm_tensormap_create_im2col";
 #endif // (CUDA_MAJOR_VERSION >= 12)
+
+// CUDA stream access policy window helpers
+constexpr const char *tvm_cuda_stream_set_access_policy_window =
+    "__tvm_cuda_stream_set_access_policy_window";
+constexpr const char *tvm_cuda_stream_reset_access_policy_window =
+    "__tvm_cuda_stream_reset_access_policy_window";
 } // namespace tl
 } // namespace tvm

-#endif // TVM_TL_RUNTIME_RUNTIME_H_
+#endif // TVM_TL_RUNTIME_RUNTIME_H_
