Closed
Labels
bug: Something isn't working
Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker and confirmed that this hasn't already been reported. (comment there if it has.)
What version of TileLang are you using?
0.1.6.post2+cuda.git551ac60d
System information
3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] linux
0.1.6.post2+cuda.git551ac60d
2.7.0+cu128
Problem description
The generated code assumes that every global store is in bounds and emits the store without any check. This behavior differs from previous versions.
Reproducible example code
The Python snippets:
import tilelang
from tilelang import language as T

@tilelang.jit
def get_sample_kernel():
    num_threads = 256
    size_0 = T.symbolic('size_0')
    size_1 = T.symbolic('size_1')

    @T.prim_func
    def sample_kernel(
        num_blocks: T.int32,
        idx_out: T.Tensor[(size_0, size_1), "int32"],
    ):
        with T.Kernel(num_blocks, threads=num_threads) as block_idx:
            idx_out[block_idx, block_idx] = 0

    return sample_kernel

kernel = get_sample_kernel()
print(kernel.get_kernel_source())
Generated Code
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif
extern "C" __global__ void sample_kernel_kernel(int* __restrict__ idx_out, int num_blocks, int size_0, int size_1);
extern "C" __global__ void __launch_bounds__(256, 1) sample_kernel_kernel(int* __restrict__ idx_out, int num_blocks, int size_0, int size_1) {
  idx_out[((((int64_t)size_1) + (int64_t)1) * ((int64_t)((int)blockIdx.x)))] = 0;
}
Expected behavior
Maybe the generated code should guard the global store with an if? Or is the unguarded store the default behavior from v2?
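To make the concern concrete, here is a minimal sketch in plain Python (not TileLang API) of the offset the generated store computes and of the predicate a guard would need to test. The helper names `flat_offset`, `store_in_bounds`, and `first_unsafe_block` are hypothetical and exist only for illustration; the sketch assumes `idx_out` is a contiguous row-major buffer of shape `(size_0, size_1)`.

```python
def flat_offset(block_idx: int, size_1: int) -> int:
    """Row-major offset of idx_out[block_idx, block_idx].

    Matches the generated expression (size_1 + 1) * blockIdx.x.
    """
    return (size_1 + 1) * block_idx


def store_in_bounds(block_idx: int, size_0: int, size_1: int) -> bool:
    """Predicate a guard would check: both indices inside the tensor shape."""
    return 0 <= block_idx < size_0 and 0 <= block_idx < size_1


def first_unsafe_block(num_blocks: int, size_0: int, size_1: int):
    """Return the first block index whose store violates the shape, or None."""
    for block_idx in range(num_blocks):
        if not store_in_bounds(block_idx, size_0, size_1):
            return block_idx
    return None


if __name__ == "__main__":
    # Shape (2, 4) holds 8 elements; block 2 would write at offset 10.
    print(first_unsafe_block(num_blocks=4, size_0=2, size_1=4))
    print(flat_offset(2, size_1=4))
```

Note that an out-of-shape store does not always leave the allocation: with shape `(4, 2)` and `block_idx = 2`, the offset is 6, which is still inside the 8-element buffer but addresses element `[3, 0]` rather than a valid `[2, 2]`. An unguarded store can therefore either write past the buffer or silently overwrite a different element, depending on the runtime values of `size_0`, `size_1`, and `num_blocks`.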
Additional context
No response