88from tvm import DataType , tir
99from tvm .runtime import convert
1010from typing import Any
11- from tvm .tir import PrimExpr , Var , Call , Buffer , BufferLoad
11+ from tvm .tir import PrimExpr , Var , Call , BufferLoad
1212
1313_IS_HIP_AVAILABLE = check_hip_availability ()
1414
@@ -430,7 +430,7 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
430430 return tir .call_intrin ("bool" , tir .op .Op .get ("tl.tl_shuffle_elect" ), thread_extent )
431431
432432
433- def warpgroup_fence_operand (buffer_or_ptr : Buffer | PrimExpr ,
433+ def warpgroup_fence_operand (buffer_or_ptr : tir . Buffer | PrimExpr ,
434434 offset : int | PrimExpr = 0 ,
435435 num_regs : int | PrimExpr | None = None ,
436436 dtype : str | None = None ):
@@ -456,7 +456,7 @@ def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr,
456456 if isinstance (buffer_or_ptr , BufferLoad ):
457457 raise TypeError ("Expected a buffer handle or pointer expression, got BufferLoad." )
458458
459- if isinstance (buffer_or_ptr , Buffer ):
459+ if isinstance (buffer_or_ptr , tir . Buffer ):
460460 data_ptr = buffer_or_ptr .data
461461 inferred_dtype = buffer_or_ptr .dtype
462462 if dtype is not None and dtype != inferred_dtype :
@@ -599,18 +599,19 @@ def sync_grid():
599599
600600
601601def initialize_wgmma_descriptor (
602- descriptor : Buffer ,
602+ descriptor : tir . Buffer ,
603603 start_address : PrimExpr ,
604604 layout_type_ : int = 0 ,
605605 leading_byte_offset : int = 0 ,
606606 stride_byte_offset : int = 0 ,
607607) -> PrimExpr :
608608 """Initialize a WGMMA/UTCMMA shared-memory descriptor."""
609609
610- if not isinstance (descriptor , (BufferLoad , Buffer )):
610+ if not isinstance (descriptor , (BufferLoad , tir . Buffer )):
611611 raise TypeError ("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad." )
612612
613- if isinstance (descriptor , Buffer ) and (len (descriptor .shape ) != 1 or descriptor .shape [0 ] != 1 ):
613+ if isinstance (descriptor , tir .Buffer ) and (len (descriptor .shape ) != 1 or
614+ descriptor .shape [0 ] != 1 ):
614615 raise ValueError ("Descriptor must be a 1D buffer of size 1." )
615616
616617 descriptor = descriptor if isinstance (descriptor , BufferLoad ) else tir .BufferLoad (
@@ -629,7 +630,7 @@ def initialize_wgmma_descriptor(
629630
630631
631632def initialize_tcgen05_descriptor (
632- descriptor : Buffer ,
633+ descriptor : tir . Buffer ,
633634 start_address : PrimExpr ,
634635 leading_byte_offset : int ,
635636 stride_byte_offset : int ,
@@ -639,10 +640,11 @@ def initialize_tcgen05_descriptor(
639640) -> PrimExpr :
640641 """Initialize a TCGEN05 shared-memory descriptor."""
641642
642- if not isinstance (descriptor , (BufferLoad , Buffer )):
643+ if not isinstance (descriptor , (BufferLoad , tir . Buffer )):
643644 raise TypeError ("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad." )
644645
645- if isinstance (descriptor , Buffer ) and (len (descriptor .shape ) != 1 or descriptor .shape [0 ] != 1 ):
646+ if isinstance (descriptor , tir .Buffer ) and (len (descriptor .shape ) != 1 or
647+ descriptor .shape [0 ] != 1 ):
646648 raise ValueError ("Descriptor must be a 1D buffer of size 1." )
647649
648650 descriptor = descriptor if isinstance (descriptor , BufferLoad ) else tir .BufferLoad (
@@ -673,10 +675,11 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
673675 Returns:
674676 PrimExpr: A handle representing the modified descriptor.
675677 """
676- if not isinstance (descriptor , (BufferLoad , Buffer )):
678+ if not isinstance (descriptor , (BufferLoad , tir . Buffer )):
677679 raise TypeError ("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad." )
678680
679- if isinstance (descriptor , Buffer ) and len (descriptor .shape ) != 1 or descriptor .shape [0 ] != 1 :
681+ if isinstance (descriptor , tir .Buffer ) and len (
682+ descriptor .shape ) != 1 or descriptor .shape [0 ] != 1 :
680683 raise ValueError ("Descriptor must be a 1D buffer of size 1." )
681684
682685 descriptor = descriptor if isinstance (descriptor , BufferLoad ) else tir .BufferLoad (
0 commit comments