Skip to content

Commit 32bdb07

Browse files
committed
fix lint error
1 parent ca8e74d commit 32bdb07

File tree

3 files changed

+21
-13
lines changed

3 files changed

+21
-13
lines changed
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import tilelang.testing
22
import tilelang.language as T
33

4+
45
def test_issue_1198():
6+
57
@T.prim_func
6-
def foo(x: T.Buffer([32,], "int32")):
8+
def foo(x: T.Buffer([
9+
32,
10+
], "int32")):
711
pass
812

13+
914
if __name__ == '__main__':
1015
tilelang.testing.main()

tilelang/language/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
)
8383
from .logical import any_of, all_of # noqa: F401
8484
from .builtin import * # noqa: F401
85-
from .proxy import Buffer # noqa: F401
85+
8686
from .utils import index_to_coordinates # noqa: F401
8787

8888
from .symbolics import dynamic, symbolic # noqa: F401

tilelang/language/builtin.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tvm import DataType, tir
99
from tvm.runtime import convert
1010
from typing import Any
11-
from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad
11+
from tvm.tir import PrimExpr, Var, Call, BufferLoad
1212

1313
_IS_HIP_AVAILABLE = check_hip_availability()
1414

@@ -430,7 +430,7 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
430430
return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent)
431431

432432

433-
def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr,
433+
def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
434434
offset: int | PrimExpr = 0,
435435
num_regs: int | PrimExpr | None = None,
436436
dtype: str | None = None):
@@ -456,7 +456,7 @@ def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr,
456456
if isinstance(buffer_or_ptr, BufferLoad):
457457
raise TypeError("Expected a buffer handle or pointer expression, got BufferLoad.")
458458

459-
if isinstance(buffer_or_ptr, Buffer):
459+
if isinstance(buffer_or_ptr, tir.Buffer):
460460
data_ptr = buffer_or_ptr.data
461461
inferred_dtype = buffer_or_ptr.dtype
462462
if dtype is not None and dtype != inferred_dtype:
@@ -599,18 +599,19 @@ def sync_grid():
599599

600600

601601
def initialize_wgmma_descriptor(
602-
descriptor: Buffer,
602+
descriptor: tir.Buffer,
603603
start_address: PrimExpr,
604604
layout_type_: int = 0,
605605
leading_byte_offset: int = 0,
606606
stride_byte_offset: int = 0,
607607
) -> PrimExpr:
608608
"""Initialize a WGMMA/UTCMMA shared-memory descriptor."""
609609

610-
if not isinstance(descriptor, (BufferLoad, Buffer)):
610+
if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
611611
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
612612

613-
if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
613+
if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or
614+
descriptor.shape[0] != 1):
614615
raise ValueError("Descriptor must be a 1D buffer of size 1.")
615616

616617
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
@@ -629,7 +630,7 @@ def initialize_wgmma_descriptor(
629630

630631

631632
def initialize_tcgen05_descriptor(
632-
descriptor: Buffer,
633+
descriptor: tir.Buffer,
633634
start_address: PrimExpr,
634635
leading_byte_offset: int,
635636
stride_byte_offset: int,
@@ -639,10 +640,11 @@ def initialize_tcgen05_descriptor(
639640
) -> PrimExpr:
640641
"""Initialize a TCGEN05 shared-memory descriptor."""
641642

642-
if not isinstance(descriptor, (BufferLoad, Buffer)):
643+
if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
643644
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
644645

645-
if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
646+
if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or
647+
descriptor.shape[0] != 1):
646648
raise ValueError("Descriptor must be a 1D buffer of size 1.")
647649

648650
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
@@ -673,10 +675,11 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
673675
Returns:
674676
PrimExpr: A handle representing the modified descriptor.
675677
"""
676-
if not isinstance(descriptor, (BufferLoad, Buffer)):
678+
if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
677679
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
678680

679-
if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
681+
if isinstance(descriptor, tir.Buffer) and len(
682+
descriptor.shape) != 1 or descriptor.shape[0] != 1:
680683
raise ValueError("Descriptor must be a 1D buffer of size 1.")
681684

682685
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(

0 commit comments

Comments
 (0)