Closed
Labels: bug
Description
```python
import torch
import tilelang
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_buggy_kernel(hidden):
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
        with T.Kernel(num_tokens, threads=128) as pid:
            smem = T.alloc_shared((hidden,), dtype='float')
            T.copy(x[pid, :], smem)
            T.cumsum(T.view(smem, (1, hidden)), dim=1)

    return buggy_kernel


if __name__ == '__main__':
    kernel = get_buggy_kernel(128)
    print(kernel.get_kernel_source())
    x = torch.zeros((1, 128), dtype=torch.float, device='cuda')
    kernel(x)
```

Error:
```
tilelang/tilelang/../3rdparty/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/ffi/cython/function.pxi", line 228, in tvm.ffi.core.Function.__call__
tvm.error.InternalError: Check failed: node->dim < static_cast<int>(node->src->shape.size()) (1 vs. 1) :
```
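The failing check `node->dim < static_cast<int>(node->src->shape.size())` compares cumsum's `dim` against the rank of its source buffer, and the `(1 vs. 1)` values suggest the 2-D shape from `T.view` never reaches that check: `dim=1` appears to be validated against the original 1-D `smem`. A minimal sketch of a possible workaround, assuming `T.cumsum` accepts `dim=0` on a 1-D shared buffer (untested; `get_workaround_kernel` is a hypothetical name for this sketch):

```python
import torch
import tilelang
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_workaround_kernel(hidden):
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def workaround_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
        with T.Kernel(num_tokens, threads=128) as pid:
            smem = T.alloc_shared((hidden,), dtype='float')
            T.copy(x[pid, :], smem)
            # Hypothetical workaround: run cumsum over the 1-D buffer
            # directly (dim=0) instead of going through a 2-D T.view,
            # so the dim check sees the buffer's real rank.
            T.cumsum(smem, dim=0)

    return workaround_kernel


if __name__ == '__main__':
    kernel = get_workaround_kernel(128)
    x = torch.zeros((1, 128), dtype=torch.float, device='cuda')
    kernel(x)
```

This only sidesteps `T.view`; the underlying bug is presumably that the viewed shape is not propagated to the cumsum lowering check.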