
Commit 92417ed

lint fix
1 parent 90fc633 commit 92417ed

File tree

3 files changed: +147 −19 lines changed


testing/python/language/test_tilelang_language_get_warp_info.py

Lines changed: 2 additions & 2 deletions
@@ -208,5 +208,5 @@ def test_shuffle_elect_block_leader():
 
 
 if __name__ == "__main__":
-    tilelang.testing.main()
-    # run_get_lane_id()
+    # tilelang.testing.main()
+    test_get_lane_idx_custom()

tilelang/language/builtin.py

Lines changed: 145 additions & 0 deletions
@@ -296,6 +296,140 @@ def warpgroup_wait(num_mma: int):
     return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma)
 
 
+def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
+    """Return the logical lane index of the calling thread within a warp.
+
+    Parameters
+    ----------
+    warp_size : Optional[int, PrimExpr]
+        Logical warp (or wavefront) size. Defaults to 32 on NVIDIA and 64 on AMD.
+
+    Example
+    -------
+    >>> lane = T.get_lane_idx()
+    >>> custom_lane = T.get_lane_idx(64)  # override warp size explicitly
+
+    Implementation Notes
+    --------------------
+    Lowers to the CUDA helper `tl::get_lane_idx(warp_size)` defined in
+    `src/tl_templates/cuda/intrin.h`, which computes the lane index from the
+    linear thread id using the provided `warp_size`.
+    """
+    warp_size_expr = _normalize_index_arg(warp_size)
+    if warp_size_expr is None:
+        return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"))
+    return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr)
+
+
+def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
+    """Return the canonical warp index, assuming the warp's threads are converged.
+
+    Parameters
+    ----------
+    warp_size : Optional[int, PrimExpr]
+        Logical warp size used for the index calculation.
+
+    Example
+    -------
+    >>> warp = T.get_warp_idx_sync()
+    >>> custom_warp = T.get_warp_idx_sync(64)
+
+    Implementation Notes
+    --------------------
+    Emits `tl::get_warp_idx_sync(warp_size)` which divides the block-linear
+    thread id by `warp_size`, matching the semantics of CUTLASS' canonical helpers.
+    """
+    warp_size_expr = _normalize_index_arg(warp_size)
+    if warp_size_expr is None:
+        return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"))
+    return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr)
+
+
+def get_warp_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
+    """Return the canonical warp index without synchronizing the warp.
+
+    Parameters
+    ----------
+    warp_size : Optional[int, PrimExpr]
+        Logical warp size used for the index calculation.
+
+    Example
+    -------
+    >>> warp = T.get_warp_idx()
+    >>> custom_warp = T.get_warp_idx(64)
+
+    Implementation Notes
+    --------------------
+    Lowers to `tl::get_warp_idx(warp_size)` which divides the block-linear
+    thread id by the provided `warp_size` without requiring warp convergence.
+    """
+    warp_size_expr = _normalize_index_arg(warp_size)
+    if warp_size_expr is None:
+        return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx"))
+    return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx"), warp_size_expr)
+
+
+def get_warp_group_idx(
+    warp_size: int | PrimExpr | None = None,
+    warps_per_group: int | PrimExpr | None = None,
+) -> PrimExpr:
+    """Return the canonical warp group index for the calling thread.
+
+    Parameters
+    ----------
+    warp_size : Optional[int, PrimExpr]
+        Logical warp size to use (defaults to 32 on NVIDIA / 64 on AMD).
+    warps_per_group : Optional[int, PrimExpr]
+        Number of warps per warp-group. Defaults to 4 on NVIDIA architectures.
+
+    Example
+    -------
+    >>> group = T.get_warp_group_idx()
+    >>> custom_group = T.get_warp_group_idx(32, 6)  # treat 6 warps as a group
+
+    Implementation Notes
+    --------------------
+    Generates `tl::get_warp_group_idx(warp_size, warps_per_group)` which
+    divides the block-linear thread id by `warp_size * warps_per_group`,
+    matching the canonical ordering while allowing architecture-specific overrides.
+    """
+    warp_size_expr = _normalize_index_arg(warp_size)
+    warps_per_group_expr = _normalize_index_arg(warps_per_group)
+    args = []
+    if warp_size_expr is not None:
+        args.append(warp_size_expr)
+    if warps_per_group_expr is not None:
+        if warp_size_expr is None:
+            raise ValueError("get_warp_group_idx expects `warp_size` when specifying "
+                             "`warps_per_group`.")
+        args.append(warps_per_group_expr)
+    return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args)
+
+
+def shuffle_elect(thread_extent: int) -> PrimExpr:
+    """Elect exactly one lane within a logical thread group.
+
+    Parameters
+    ----------
+    thread_extent : int
+        Size (in threads) of the group in which a single lane should be elected.
+        Passing 0 elects a single lane in the entire thread block.
+
+    Example
+    -------
+    >>> is_leader = T.shuffle_elect(64)
+    >>> T.if_then_else(is_leader, do_leader_work(), T.evaluate(0))
+
+    Implementation Notes
+    --------------------
+    Lowered to the CUDA helper `tl::tl_shuffle_elect<thread_extent>()` defined in
+    `src/tl_templates/cuda/intrin.h`, which relies on
+    `cutlass::canonical_warp_idx_sync()` and `cute::elect_one_sync()` (or
+    `__shfl_sync`) to pick one lane per group.
+    """
+    return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent)
+
+
 def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr,
                             offset: int | PrimExpr = 0,
                             num_regs: int | PrimExpr | None = None,
@@ -563,3 +697,14 @@ def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call):
     """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
     """
    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
+
+
+def tcgen05_mma_arrive(mbar_ptr):
+    """Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.
+
+    Parameters
+    ----------
+    mbar_ptr : PrimExpr
+        Pointer to the mbarrier object in shared memory (e.g., Barrier*).
+    """
+    return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)

tilelang/language/tir/ir.py

Lines changed: 0 additions & 17 deletions
@@ -4,7 +4,6 @@
 from tvm.tir import PrimExpr
 from typing import Any
 import tilelang.language.tir.op as _tir_op
-import tvm.tir.op as _tvm_tir_op
 import functools
 
 
@@ -309,19 +308,3 @@ def wrapped(*args, **kwargs):
 tvm_mfma_store = _dtype_forward(_tir_op.tvm_mfma_store)
 tvm_rdna_wmma = _dtype_forward(_tir_op.tvm_rdna_wmma)
 tvm_rdna_wmma_store = _dtype_forward(_tir_op.tvm_rdna_wmma_store)
-
-
-# Convenience wrapper for TL shuffle elect; returns a boolean PrimExpr
-def tl_shuffle_elect(thread_extent: PrimExpr | int = 0):
-    return _tir_op.call_intrin("bool", _tvm_tir_op.Op.get("tl.tl_shuffle_elect"), thread_extent)
-
-
-def tcgen05_mma_arrive(mbar_ptr):
-    """Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.
-
-    Parameters
-    ----------
-    mbar_ptr : PrimExpr
-        Pointer to the mbarrier object in shared memory (e.g., Barrier*).
-    """
-    return call_intrin("void", _tvm_tir_op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)
