Skip to content

[Bug] Failed to infer layout with T.cumsum + view #1001

@LyricZhao

Description

@LyricZhao
import torch
import tilelang
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        # Set both TL_DISABLE_* pass-config flags to True (presumably turning off
        # warp specialization and TMA lowering for this reproducer — confirm).
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_buggy_kernel(hidden):
    """Build the minimal kernel that reproduces the T.cumsum + T.view failure.

    Args:
        hidden: Size of the second dimension of the input tensor ``x``; also the
            length of the 1-D buffer allocated with ``T.alloc_shared``.

    Returns:
        The inner ``buggy_kernel`` prim_func. The ``tilelang.jit`` decorator
        processes this function, so the object the caller receives is whatever
        the decorator produces (not visible from this snippet).
    """
    # The row count is a symbolic dimension rather than a fixed integer.
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
        # Launch extent is num_tokens with threads=128; pid selects the row of x
        # that is copied below.
        with T.Kernel(num_tokens, threads=128) as pid:
            # 1-D buffer of length `hidden`, filled with row `pid` of x.
            smem = T.alloc_shared((hidden, ), dtype='float')
            T.copy(x[pid, :], smem)
            # View the 1-D buffer as (1, hidden) and run cumsum along dim 1.
            # NOTE(review): the reported error is
            # "Check failed: node->dim < static_cast<int>(node->src->shape.size()) (1 vs. 1)",
            # which suggests cumsum still sees the original 1-D shape instead of
            # the 2-D view — confirm against the layout-inference pass.
            T.cumsum(T.view(smem, (1, hidden)), dim=1)

    return buggy_kernel


if __name__ == '__main__':
    # Build the reproducer for a hidden size of 128 and print its generated source.
    hidden_size = 128
    compiled = get_buggy_kernel(hidden_size)
    source_text = compiled.get_kernel_source()
    print(source_text)

    # Invoke the kernel on a single all-zero float row created on the CUDA device.
    zeros_row = torch.zeros((1, hidden_size), dtype=torch.float, device='cuda')
    compiled(zeros_row)

Error:

  File "tilelang/tilelang/../3rdparty/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/ffi/cython/function.pxi", line 228, in tvm.ffi.core.Function.__call__
tvm.error.InternalError: Check failed: node->dim < static_cast<int>(node->src->shape.size()) (1 vs. 1) :

Metadata

Metadata

Assignees

Labels

bug: Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions