diff --git a/testing/python/issue/test_tilelang_issue_1198.py b/testing/python/issue/test_tilelang_issue_1198.py new file mode 100644 index 000000000..eb9ed4596 --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1198.py @@ -0,0 +1,15 @@ +import tilelang.testing +import tilelang.language as T + + +def test_issue_1198(): + + @T.prim_func + def foo(x: T.Buffer([ + 32, + ], "int32")): + pass + + +if __name__ == '__main__': + tilelang.testing.main() diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index da696517f..a3f2482d2 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -8,7 +8,7 @@ from tvm import DataType, tir from tvm.runtime import convert from typing import Any -from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad +from tvm.tir import PrimExpr, Var, Call, BufferLoad _IS_HIP_AVAILABLE = check_hip_availability() @@ -430,7 +430,7 @@ def shuffle_elect(thread_extent: int) -> PrimExpr: return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent) -def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr, +def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, offset: int | PrimExpr = 0, num_regs: int | PrimExpr | None = None, dtype: str | None = None): @@ -456,7 +456,7 @@ def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr, if isinstance(buffer_or_ptr, BufferLoad): raise TypeError("Expected a buffer handle or pointer expression, got BufferLoad.") - if isinstance(buffer_or_ptr, Buffer): + if isinstance(buffer_or_ptr, tir.Buffer): data_ptr = buffer_or_ptr.data inferred_dtype = buffer_or_ptr.dtype if dtype is not None and dtype != inferred_dtype: @@ -599,7 +599,7 @@ def sync_grid(): def initialize_wgmma_descriptor( - descriptor: Buffer, + descriptor: tir.Buffer, start_address: PrimExpr, layout_type_: int = 0, leading_byte_offset: int = 0, @@ -607,10 +607,11 @@ def initialize_wgmma_descriptor( ) -> PrimExpr: """Initialize a WGMMA/UTCMMA shared-memory descriptor.""" - if not isinstance(descriptor, (BufferLoad, Buffer)): + if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): + if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or + descriptor.shape[0] != 1): raise ValueError("Descriptor must be a 1D buffer of size 1.") descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( @@ -629,7 +630,7 @@ def initialize_wgmma_descriptor( def initialize_tcgen05_descriptor( - descriptor: Buffer, + descriptor: tir.Buffer, start_address: PrimExpr, leading_byte_offset: int, stride_byte_offset: int, @@ -639,10 +640,11 @@ def initialize_tcgen05_descriptor( ) -> PrimExpr: """Initialize a TCGEN05 shared-memory descriptor.""" - if not isinstance(descriptor, (BufferLoad, Buffer)): + if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): + if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or + descriptor.shape[0] != 1): raise ValueError("Descriptor must be a 1D buffer of size 1.") descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( @@ -673,10 +675,11 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx Returns: PrimExpr: A handle representing the modified descriptor. """ - if not isinstance(descriptor, (BufferLoad, Buffer)): + if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1: + if isinstance(descriptor, tir.Buffer) and len( + descriptor.shape) != 1 or descriptor.shape[0] != 1: raise ValueError("Descriptor must be a 1D buffer of size 1.") descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(