@LeiWang1999

This pull request introduces several new examples demonstrating specialized warp-level operations for matrix multiplication and attention mechanisms using the TileLang framework. These examples showcase different approaches to optimizing kernel execution with techniques like pipelining, barrier synchronization, and shared memory utilization.

Flash Attention Example:

  • Added a new example implementing the Flash Attention mechanism with warp-level specialization, including shared memory allocation, barrier synchronization, and swizzling for efficient computation. The implementation also includes a reference program and benchmarking for performance evaluation. (examples/warp_specialize/example_warp_specialize_flashmla.py)

Matrix Multiplication Examples:

  • Barrier-Pipelined Stage 2: Added an example demonstrating matrix multiplication with a two-stage barrier-pipelined approach, utilizing shared memory and mbarriers for efficient data movement and computation. (examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py)
  • Warp Specialize Copy 0, Gemm 1: Added an example showcasing a warp-specialized matrix multiplication where the copy operation precedes the GEMM operation. This example uses mbarriers and pipelining. (examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py)
  • Warp Specialize Copy 1, Gemm 0: Added an example with a reversed order of operations (GEMM precedes the copy operation) compared to the previous example, demonstrating another approach to warp specialization. (examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py)
  • Soft-Pipelined Stage 2: Added an example of matrix multiplication with a soft-pipelined approach using two stages, highlighting the use of pipelining and shared memory for performance optimization. (examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py)
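
The two-stage pipelining idea behind these GEMM examples can be illustrated on the CPU. The sketch below is plain Python rather than TileLang, and it runs sequentially, so it only shows the buffer rotation and not real overlap. `load_tile`, `accumulate`, `pipelined_matmul`, and the two-slot `stages` list are hypothetical names that stand in for the global-to-shared copy, the per-tile GEMM, and the staged shared-memory buffers.

```python
NUM_STAGES = 2  # two staged buffers, as in the "stage 2" examples


def load_tile(A, B, k0, block_k):
    """Copy one K-slice of A (columns) and B (rows); stands in for a global->shared copy."""
    a_tile = [row[k0:k0 + block_k] for row in A]
    b_tile = B[k0:k0 + block_k]
    return a_tile, b_tile


def accumulate(C, a_tile, b_tile):
    """C += a_tile @ b_tile; stands in for the GEMM on one staged tile."""
    for i, a_row in enumerate(a_tile):
        for kk, a_val in enumerate(a_row):
            for j, b_val in enumerate(b_tile[kk]):
                C[i][j] += a_val * b_val


def pipelined_matmul(A, B, block_k=2):
    m, k, n = len(A), len(B), len(B[0])
    C = [[0] * n for _ in range(m)]
    num_tiles = (k + block_k - 1) // block_k
    stages = [None] * NUM_STAGES  # hypothetical stand-in for staged shared buffers

    stages[0] = load_tile(A, B, 0, block_k)  # prologue: fill the first stage
    for t in range(num_tiles):
        if t + 1 < num_tiles:
            # Prefetch the next K-tile into the slot that is not being consumed.
            stages[(t + 1) % NUM_STAGES] = load_tile(A, B, (t + 1) * block_k, block_k)
        a_tile, b_tile = stages[t % NUM_STAGES]  # consume the tile loaded one step earlier
        accumulate(C, a_tile, b_tile)
    return C
```

On a GPU the prefetch and the accumulation would run concurrently, which is why the two slots must never alias; here the rotation `t % NUM_STAGES` versus `(t + 1) % NUM_STAGES` preserves that property.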

…ogic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.
…nd related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.
* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.
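
For code that still refers to the old names, the rename can be applied mechanically. The helper below is a hypothetical migration sketch, not part of this PR; it covers only the three old/new pairs named in the notes above, and `apply_renames` is an illustrative name.

```python
import re

# Old -> new names taken from the commit notes; other renamed ops are not listed there.
RENAMES = {
    "CreateTMADescriptorOp": "create_tma_descriptor",
    "CreateListofMBarrierOp": "create_list_of_mbarrier",
    "GetMBarrierOp": "get_mbarrier",
}

_PATTERN = re.compile(r"\b(" + "|".join(map(re.escape, RENAMES)) + r")\b")


def apply_renames(source: str) -> str:
    """Replace whole-word occurrences of the old names with their snake_case forms."""
    return _PATTERN.sub(lambda m: RENAMES[m.group(1)], source)
```

An explicit table is used instead of a generic CamelCase converter because names such as `CreateListofMBarrierOp` do not map to `create_list_of_mbarrier` under a simple capital-letter split.
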
…r MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.
* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.
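
As a rough illustration of the warp group indices mentioned above: on Hopper-class NVIDIA GPUs a warp has 32 threads and a warp group is four consecutive warps, so 128 threads. The sketch below is plain Python, not the `WarpSpecialize` API; `warp_group_index` and `role_of` are hypothetical helpers, and treating group 0 as the producer is an assumption made only for the example.

```python
WARP_SIZE = 32                                   # threads per warp on NVIDIA GPUs
WARPS_PER_GROUP = 4                              # a warp group is four consecutive warps
WARP_GROUP_SIZE = WARP_SIZE * WARPS_PER_GROUP    # 128 threads


def warp_group_index(thread_idx: int) -> int:
    """Map a linear thread index within the block to its warp group index."""
    return thread_idx // WARP_GROUP_SIZE


def role_of(thread_idx: int, producer_groups=(0,)) -> str:
    """Assumed split: listed groups issue copies, all other groups run the GEMM."""
    return "producer" if warp_group_index(thread_idx) in producer_groups else "consumer"
```

With a 256-thread block this gives two warp groups: threads 0 to 127 take one branch and threads 128 to 255 take the other.
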
…ency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.
…ctions in `builtin.py` for improved code readability and separation of logical sections.
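
The phase-parity protocol behind these barrier functions can be modelled on the CPU. The sketch below is a simplified Python model using threads, not the TileLang or PTX implementation; `ToyMBarrier`, `arrive`, `wait_parity`, and `run_pipeline` are hypothetical names, loosely echoing `mbarrier_arrive` and `mbarrier_wait_parity`. It assumes a wait on parity `p` returns once the most recent phase with parity `p` has completed.

```python
import threading


class ToyMBarrier:
    """Simplified CPU model of a phase-parity barrier."""

    def __init__(self, arrive_count: int):
        self._arrive_count = arrive_count
        self._pending = arrive_count
        self._phase = 0                       # index of the phase currently in progress
        self._cond = threading.Condition()

    def arrive(self) -> None:
        with self._cond:
            self._pending -= 1
            if self._pending == 0:            # the last arrival completes the phase
                self._phase += 1
                self._pending = self._arrive_count
                self._cond.notify_all()

    def wait_parity(self, parity: int, timeout: float = 5.0) -> None:
        """Block until the phase with the given parity has completed."""
        with self._cond:
            # The phase in progress has parity _phase % 2; any other parity is already done.
            done = self._cond.wait_for(lambda: self._phase % 2 != parity, timeout)
            if not done:
                raise TimeoutError("barrier phase did not complete")


def run_pipeline(items, num_stages=2):
    """Producer fills staged slots, consumer doubles each value; barriers order the hand-off."""
    full = [ToyMBarrier(1) for _ in range(num_stages)]
    empty = [ToyMBarrier(1) for _ in range(num_stages)]
    slots = [None] * num_stages
    out = []

    def producer():
        for t, item in enumerate(items):
            s, r = t % num_stages, t // num_stages
            empty[s].wait_parity((r % 2) ^ 1)  # returns immediately in round 0
            slots[s] = item
            full[s].arrive()

    def consumer():
        for t in range(len(items)):
            s, r = t % num_stages, t // num_stages
            full[s].wait_parity(r % 2)
            out.append(slots[s] * 2)
            empty[s].arrive()

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    return out
```

The producer waits on the inverted parity so that its first pass over each slot does not block on a fresh barrier; after that, the parity flips once per round and each side waits only for the other side's most recent arrival.
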
…ation

* Introduced three new example scripts: `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, and `example_warp_specialize_mla.py` demonstrating matrix multiplication with warp specialization and TMA barriers.
* Implemented kernel functions with shared memory allocation and memory barrier synchronization for improved performance.
* Enhanced the TileLang API with new methods for compiling and testing kernels in Python using PyTorch.
* Updated the `phase.py` to include TMA barrier injection in the optimization process.
* Improved documentation and comments for better clarity on usage and functionality.
* Introduced a new example script `example_warp_specialize_gemm_stage2.py` demonstrating matrix multiplication using warp specialization and TMA barriers.
* Implemented a kernel function with shared memory allocation and memory barrier synchronization for enhanced performance.
* Included functionality to compile the kernel into a PyTorch-compatible function and validate its correctness against PyTorch's reference implementation.
* Enhanced documentation and comments for clarity on usage and functionality.
…ction

* Added the `WarpSpecializedDetector` class to identify the presence of TMA operations and memory barrier operations within a given TIR statement.
* Enhanced the `WarpSpecialized` pass to utilize the detector, allowing for conditional substitution based on the detection results.
* Improved code organization by including necessary headers and utilizing the `IRVisitorWithAnalyzer` for analysis.
* This addition aims to optimize warp specialization by ensuring that only relevant functions are transformed, enhancing performance and correctness.
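
A detector of this kind can be pictured as a tree walk that sets two flags. The sketch below is a toy Python model, not the C++ `WarpSpecializedDetector`; `Node`, `ToyDetector`, `maybe_specialize`, and the op names `tma_load` and `tma_store` are hypothetical. It also assumes that the automatic rewrite is skipped when both TMA and mbarrier ops are already present; the notes above only say that the substitution is conditional.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy IR node: an op name plus child nodes (a stand-in for a TIR statement tree)."""
    op: str
    children: list = field(default_factory=list)


TMA_OPS = {"tma_load", "tma_store"}                        # hypothetical op names
MBARRIER_OPS = {"mbarrier_arrive", "mbarrier_wait_parity"}


class ToyDetector:
    def __init__(self):
        self.has_tma = False
        self.has_mbarrier = False

    def visit(self, node: Node) -> "ToyDetector":
        """Walk the whole tree and record which kinds of op appear."""
        if node.op in TMA_OPS:
            self.has_tma = True
        if node.op in MBARRIER_OPS:
            self.has_mbarrier = True
        for child in node.children:
            self.visit(child)
        return self

    @property
    def detected(self) -> bool:
        return self.has_tma and self.has_mbarrier


def maybe_specialize(body: Node, rewrite):
    """Assumed policy: leave hand-specialized bodies alone, rewrite everything else."""
    if ToyDetector().visit(body).detected:
        return body
    return rewrite(body)
```
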
* Introduced multiple new example scripts demonstrating warp specialization techniques, including `example_warp_specialize_flashmla.py`, `example_warp_specialize_gemm_barrierpipe_stage2.py`, `example_warp_specialize_gemm_copy_0_gemm_1.py`, `example_warp_specialize_gemm_copy_1_gemm_0.py`, and `example_warp_specialize_gemm_softpipe_stage2.py`.
* Each example showcases matrix multiplication with warp specialization and TMA barriers, implementing kernel functions with shared memory allocation and memory barrier synchronization for enhanced performance.
* Added a test suite in `test_example_warp_specialize.py` to validate the functionality of the new examples.
* Updated the TileLang API to support these examples and improve kernel compilation and testing processes.
* Removed outdated example scripts to streamline the codebase and enhance clarity on available functionalities.
…ration to streamline the codebase. This includes `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, `example_warp_specialize_gemm_stage2.py`, and `example_warp_specialize_mla.py`, which are no longer needed following recent updates and improvements in the TileLang API.
@LeiWang1999 LeiWang1999 merged commit 5b8ab6d into tile-ai:main May 6, 2025
3 checks passed
@LeiWang1999 LeiWang1999 deleted the feat_250505_speclialize branch May 6, 2025 03:33
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 18, 2025
…ation (tile-ai#448)

* [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.

* [Refactor] Rename operations for consistency in lower_hopper_intrin and related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.

* [Refactor] Rename operations to snake_case for consistency

* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.

* [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.

* [Refactor] Clean up code formatting and improve readability

* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.

* lint fix

* [Refactor] Update mbarrier functions for improved clarity and consistency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.

* Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` for improved code readability and separation of logical sections.

* [Feature] Add examples for warp specialization and TMA barrier integration

* Introduced three new example scripts: `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, and `example_warp_specialize_mla.py` demonstrating matrix multiplication with warp specialization and TMA barriers.
* Implemented kernel functions with shared memory allocation and memory barrier synchronization for improved performance.
* Enhanced the TileLang API with new methods for compiling and testing kernels in Python using PyTorch.
* Updated the `phase.py` to include TMA barrier injection in the optimization process.
* Improved documentation and comments for better clarity on usage and functionality.

* [Feature] Add example for warp specialization in GEMM with TMA barriers

* Introduced a new example script `example_warp_specialize_gemm_stage2.py` demonstrating matrix multiplication using warp specialization and TMA barriers.
* Implemented a kernel function with shared memory allocation and memory barrier synchronization for enhanced performance.
* Included functionality to compile the kernel into a PyTorch-compatible function and validate its correctness against PyTorch's reference implementation.
* Enhanced documentation and comments for clarity on usage and functionality.

* lint fix

* [Feature] Implement WarpSpecializedDetector for TMA and MBarrier Detection

* Added the `WarpSpecializedDetector` class to identify the presence of TMA operations and memory barrier operations within a given TIR statement.
* Enhanced the `WarpSpecialized` pass to utilize the detector, allowing for conditional substitution based on the detection results.
* Improved code organization by including necessary headers and utilizing the `IRVisitorWithAnalyzer` for analysis.
* This addition aims to optimize warp specialization by ensuring that only relevant functions are transformed, enhancing performance and correctness.

* lint fix

* [Feature] Add new examples for warp specialization and TMA integration

* Introduced multiple new example scripts demonstrating warp specialization techniques, including `example_warp_specialize_flashmla.py`, `example_warp_specialize_gemm_barrierpipe_stage2.py`, `example_warp_specialize_gemm_copy_0_gemm_1.py`, `example_warp_specialize_gemm_copy_1_gemm_0.py`, and `example_warp_specialize_gemm_softpipe_stage2.py`.
* Each example showcases matrix multiplication with warp specialization and TMA barriers, implementing kernel functions with shared memory allocation and memory barrier synchronization for enhanced performance.
* Added a test suite in `test_example_warp_specialize.py` to validate the functionality of the new examples.
* Updated the TileLang API to support these examples and improve kernel compilation and testing processes.
* Removed outdated example scripts to streamline the codebase and enhance clarity on available functionalities.

* lint fix

* Remove outdated example scripts for warp specialization and TMA integration to streamline the codebase. This includes `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, `example_warp_specialize_gemm_stage2.py`, and `example_warp_specialize_mla.py`, which are no longer needed following recent updates and improvements in the TileLang API.
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 20, 2025
…ation (tile-ai#448)
