@@ -296,6 +296,140 @@ def warpgroup_wait(num_mma: int):
     return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma)
 
 
+def get_lane_idx(warp_size: int | PrimExpr | None = None) -> PrimExpr:
+    """Return the logical lane index of the calling thread within a warp.
+
+    Parameters
+    ----------
+    warp_size : int or PrimExpr, optional
+        Logical warp (or wavefront) size. Defaults to 32 on NVIDIA and 64 on AMD.
+
+    Example
+    -------
+    >>> lane = T.get_lane_idx()
+    >>> custom_lane = T.get_lane_idx(64)  # override the warp size explicitly
+
+    Implementation Notes
+    --------------------
+    Lowers to the CUDA helper `tl::get_lane_idx(warp_size)` defined in
+    `src/tl_templates/cuda/intrin.h`, which computes the lane index from the
+    linear thread id using the provided `warp_size`.
+    """
+    warp_size_expr = _normalize_index_arg(warp_size)
+    if warp_size_expr is None:
+        return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"))
+    return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr)
+
+
+def get_warp_idx_sync(warp_size: int | PrimExpr | None = None) -> PrimExpr:
+    """Return the canonical warp index, assuming the warp's threads are converged.
+
+    Parameters
+    ----------
+    warp_size : int or PrimExpr, optional
+        Logical warp size used for the index calculation.
+
+    Example
+    -------
+    >>> warp = T.get_warp_idx_sync()
+    >>> custom_warp = T.get_warp_idx_sync(64)
+
+    Implementation Notes
+    --------------------
+    Emits `tl::get_warp_idx_sync(warp_size)`, which divides the block-linear
+    thread id by `warp_size`, matching the semantics of CUTLASS' canonical helpers.
+    """
+    warp_size_expr = _normalize_index_arg(warp_size)
+    if warp_size_expr is None:
+        return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"))
+    return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr)
+
+
+def get_warp_idx(warp_size: int | PrimExpr | None = None) -> PrimExpr:
+    """Return the canonical warp index without synchronizing the warp.
+
+    Parameters
+    ----------
+    warp_size : int or PrimExpr, optional
+        Logical warp size used for the index calculation.
+
+    Example
+    -------
+    >>> warp = T.get_warp_idx()
+    >>> custom_warp = T.get_warp_idx(64)
+
+    Implementation Notes
+    --------------------
+    Lowers to `tl::get_warp_idx(warp_size)`, which divides the block-linear
+    thread id by the provided `warp_size` without requiring warp convergence.
+    """
+    warp_size_expr = _normalize_index_arg(warp_size)
+    if warp_size_expr is None:
+        return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx"))
+    return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx"), warp_size_expr)
+
+
+def get_warp_group_idx(
+    warp_size: int | PrimExpr | None = None,
+    warps_per_group: int | PrimExpr | None = None,
+) -> PrimExpr:
+    """Return the canonical warp group index for the calling thread.
+
+    Parameters
+    ----------
+    warp_size : int or PrimExpr, optional
+        Logical warp size to use (defaults to 32 on NVIDIA / 64 on AMD).
+    warps_per_group : int or PrimExpr, optional
+        Number of warps per warp group. Defaults to 4 on NVIDIA architectures.
+
+    Example
+    -------
+    >>> group = T.get_warp_group_idx()
+    >>> custom_group = T.get_warp_group_idx(32, 6)  # treat 6 warps as a group
+
+    Implementation Notes
+    --------------------
+    Generates `tl::get_warp_group_idx(warp_size, warps_per_group)`, which
+    divides the block-linear thread id by `warp_size * warps_per_group`,
+    matching the canonical ordering while allowing architecture-specific overrides.
+    """
+    warp_size_expr = _normalize_index_arg(warp_size)
+    warps_per_group_expr = _normalize_index_arg(warps_per_group)
+    args = []
+    if warp_size_expr is not None:
+        args.append(warp_size_expr)
+    if warps_per_group_expr is not None:
+        if warp_size_expr is None:
+            raise ValueError("get_warp_group_idx expects `warp_size` when specifying "
+                             "`warps_per_group`.")
+        args.append(warps_per_group_expr)
+    return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args)
+
+
+def shuffle_elect(thread_extent: int) -> PrimExpr:
+    """Elect exactly one lane within a logical thread group.
+
+    Parameters
+    ----------
+    thread_extent : int
+        Size (in threads) of the group in which a single lane should be elected.
+        Passing 0 elects a single lane in the entire thread block.
+
+    Example
+    -------
+    >>> is_leader = T.shuffle_elect(64)
+    >>> T.if_then_else(is_leader, do_leader_work(), T.evaluate(0))
+
+    Implementation Notes
+    --------------------
+    Lowered to the CUDA helper `tl::tl_shuffle_elect<thread_extent>()` defined in
+    `src/tl_templates/cuda/intrin.h`, which relies on
+    `cutlass::canonical_warp_idx_sync()` and `cute::elect_one_sync()` (or
+    `__shfl_sync`) to pick one lane per group.
+    """
+    return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent)
+
+
 def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr,
                              offset: int | PrimExpr = 0,
                              num_regs: int | PrimExpr | None = None,
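
A minimal usage sketch of the new index and election intrinsics inside a TileLang kernel. The kernel scaffolding, grid size, and output buffer below are illustrative assumptions for this sketch, not part of the patch:

    import tilelang.language as T

    @T.prim_func
    def record_indices(Out: T.Buffer((8, 3), "int32")):
        # 8 blocks of 128 threads (= 4 warps with the default NVIDIA warp size).
        with T.Kernel(8, threads=128) as bx:
            # Elect a single lane in the whole thread block to publish its indices.
            if T.shuffle_elect(0):
                Out[bx, 0] = T.get_lane_idx()        # lane within the warp (0..31)
                Out[bx, 1] = T.get_warp_idx_sync()   # block-linear thread id // 32
                Out[bx, 2] = T.get_warp_group_idx()  # warp index // 4 by default
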
@@ -563,3 +697,14 @@ def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call):
563697 """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
564698 """
565699 return tir .call_intrin ("handle" , tir .op .Op .get ("tl.ptx_cp_async_barrier_noinc" ), barrier_id )
700+
701+
+def tcgen05_mma_arrive(mbar_ptr):
+    """Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.
+
+    Parameters
+    ----------
+    mbar_ptr : PrimExpr
+        Pointer to the mbarrier object in shared memory (e.g., a Barrier*).
+    """
+    return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)
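
A hypothetical call-site sketch for `tcgen05_mma_arrive`. How the shared-memory mbarrier pointer is obtained depends on the surrounding kernel; the `mbar` handle below is an assumption and is not defined by this patch:

    # `mbar` is assumed to already point at an mbarrier object that the kernel
    # has placed in shared memory (allocation and initialization not shown).
    T.tcgen05_mma_arrive(mbar)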