From ca8e74da9948e4e5bce5bfc3c6334be0a6a840c5 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 10 Nov 2025 11:39:15 +0800 Subject: [PATCH 1/2] Fix Buffer re-import typo in tilelang.langugage --- testing/python/issue/test_tilelang_issue_1198.py | 10 ++++++++++ tilelang/language/__init__.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 testing/python/issue/test_tilelang_issue_1198.py diff --git a/testing/python/issue/test_tilelang_issue_1198.py b/testing/python/issue/test_tilelang_issue_1198.py new file mode 100644 index 000000000..e4f4c412f --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1198.py @@ -0,0 +1,10 @@ +import tilelang.testing +import tilelang.language as T + +def test_issue_1198(): + @T.prim_func + def foo(x: T.Buffer([32,], "int32")): + pass + +if __name__ == '__main__': + tilelang.testing.main() diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 43c721bbb..ab6ad8fb4 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -82,7 +82,7 @@ ) from .logical import any_of, all_of # noqa: F401 from .builtin import * # noqa: F401 - +from .proxy import Buffer # noqa: F401 from .utils import index_to_coordinates # noqa: F401 from .symbolics import dynamic, symbolic # noqa: F401 From 32bdb070e6c1cb7d3a5cc19693de72b01a75b2a6 Mon Sep 17 00:00:00 2001 From: Freebase6912 <227995639+kurisu6912@users.noreply.github.com> Date: Mon, 10 Nov 2025 11:55:48 +0800 Subject: [PATCH 2/2] fix lint error --- .../python/issue/test_tilelang_issue_1198.py | 7 +++++- tilelang/language/__init__.py | 2 +- tilelang/language/builtin.py | 25 +++++++++++-------- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/testing/python/issue/test_tilelang_issue_1198.py b/testing/python/issue/test_tilelang_issue_1198.py index e4f4c412f..eb9ed4596 100644 --- a/testing/python/issue/test_tilelang_issue_1198.py +++ b/testing/python/issue/test_tilelang_issue_1198.py @@ -1,10 +1,15 @@ import tilelang.testing import tilelang.language as T + def test_issue_1198(): + @T.prim_func - def foo(x: T.Buffer([32,], "int32")): + def foo(x: T.Buffer([ + 32, + ], "int32")): pass + if __name__ == '__main__': tilelang.testing.main() diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index ab6ad8fb4..43c721bbb 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -82,7 +82,7 @@ ) from .logical import any_of, all_of # noqa: F401 from .builtin import * # noqa: F401 -from .proxy import Buffer # noqa: F401 + from .utils import index_to_coordinates # noqa: F401 from .symbolics import dynamic, symbolic # noqa: F401 diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index da696517f..a3f2482d2 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -8,7 +8,7 @@ from tvm import DataType, tir from tvm.runtime import convert from typing import Any -from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad +from tvm.tir import PrimExpr, Var, Call, BufferLoad _IS_HIP_AVAILABLE = check_hip_availability() @@ -430,7 +430,7 @@ def shuffle_elect(thread_extent: int) -> PrimExpr: return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent) -def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr, +def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, offset: int | PrimExpr = 0, num_regs: int | PrimExpr | None = None, dtype: str | None = None): @@ -456,7 +456,7 @@ def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr, if isinstance(buffer_or_ptr, BufferLoad): raise TypeError("Expected a buffer handle or pointer expression, got BufferLoad.") - if isinstance(buffer_or_ptr, Buffer): + if isinstance(buffer_or_ptr, tir.Buffer): data_ptr = buffer_or_ptr.data inferred_dtype = buffer_or_ptr.dtype if dtype is not None and dtype != inferred_dtype: @@ -599,7 +599,7 @@ def sync_grid(): def initialize_wgmma_descriptor( - descriptor: Buffer, + descriptor: tir.Buffer, start_address: PrimExpr, layout_type_: int = 0, leading_byte_offset: int = 0, @@ -607,10 +607,11 @@ def initialize_wgmma_descriptor( ) -> PrimExpr: """Initialize a WGMMA/UTCMMA shared-memory descriptor.""" - if not isinstance(descriptor, (BufferLoad, Buffer)): + if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): + if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or + descriptor.shape[0] != 1): raise ValueError("Descriptor must be a 1D buffer of size 1.") descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( @@ -629,7 +630,7 @@ def initialize_wgmma_descriptor( def initialize_tcgen05_descriptor( - descriptor: Buffer, + descriptor: tir.Buffer, start_address: PrimExpr, leading_byte_offset: int, stride_byte_offset: int, @@ -639,10 +640,11 @@ def initialize_tcgen05_descriptor( ) -> PrimExpr: """Initialize a TCGEN05 shared-memory descriptor.""" - if not isinstance(descriptor, (BufferLoad, Buffer)): + if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): + if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or + descriptor.shape[0] != 1): raise ValueError("Descriptor must be a 1D buffer of size 1.") descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( @@ -673,10 +675,11 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx Returns: PrimExpr: A handle representing the modified descriptor. """ - if not isinstance(descriptor, (BufferLoad, Buffer)): + if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1: + if isinstance(descriptor, tir.Buffer) and len( + descriptor.shape) != 1 or descriptor.shape[0] != 1: raise ValueError("Descriptor must be a 1D buffer of size 1.") descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(