Skip to content

Conversation

@LeiWang1999
Copy link
Member

This pull request introduces new examples for matrix multiplication kernels using the tilelang framework, along with enhancements to the GEMM layout and bulk copy operations in the core codebase. The most significant changes include the addition of new example files showcasing specialized GEMM implementations, improvements to layout inferencing for GEMM operations, and updates to bulk copy logic to support enhanced indexing and memory access.

New Examples for Specialized GEMM Kernels:

  • Added new example files (example_warp_specialize_gemm.py, example_warp_specialize_gemm_barrier4.py, example_warp_specialize_gemm_stage2.py, and example_warp_specialize_mla.py) demonstrating various matrix multiplication techniques with features like pipelining, barriers, and multi-stage execution. These examples include Python and PyTorch integration for testing and profiling. [1] [2] [3]

Enhancements to GEMM Layouts:

  • Updated Gemm::InferLayout in src/op/gemm.cc to dynamically determine matrix dimensions for layout inference, ensuring compatibility with varying tensor shapes. [1] [2] [3]
  • Modified Gemm constructor to store raw buffer pointers (Aptr, Bptr, Cptr) and use them in lowering, improving clarity and flexibility. [1] [2]

Improvements to Bulk Copy Operations:

  • Enhanced LowerBulkCopy in src/op/bulk_copy.cc to calculate memory offsets and strides for shared tensors, enabling more precise memory access during bulk copy operations.
  • Adjusted thread synchronization logic in LowerBulkCopy and Conv2DIm2ColOp::Lower to account for thread bounds, improving correctness in multi-threaded scenarios. [1] [2]

Minor Improvements:

  • Added detailed error messages for layout checks in makeFullBankSwizzleLayout to aid debugging.

…ogic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.
…nd related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.
* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.
…r MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.
* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.
…ency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.
…ctions in `builtin.py` for improved code readability and separation of logical sections.
…ation

* Introduced three new example scripts: `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, and `example_warp_specialize_mla.py` demonstrating matrix multiplication with warp specialization and TMA barriers.
* Implemented kernel functions with shared memory allocation and memory barrier synchronization for improved performance.
* Enhanced the TileLang API with new methods for compiling and testing kernels in Python using PyTorch.
* Updated the `phase.py` to include TMA barrier injection in the optimization process.
* Improved documentation and comments for better clarity on usage and functionality.
* Introduced a new example script `example_warp_specialize_gemm_stage2.py` demonstrating matrix multiplication using warp specialization and TMA barriers.
* Implemented a kernel function with shared memory allocation and memory barrier synchronization for enhanced performance.
* Included functionality to compile the kernel into a PyTorch-compatible function and validate its correctness against PyTorch's reference implementation.
* Enhanced documentation and comments for clarity on usage and functionality.
…ction

* Added the `WarpSpecializedDetector` class to identify the presence of TMA operations and memory barrier operations within a given TIR statement.
* Enhanced the `WarpSpecialized` pass to utilize the detector, allowing for conditional substitution based on the detection results.
* Improved code organization by including necessary headers and utilizing the `IRVisitorWithAnalyzer` for analysis.
* This addition aims to optimize warp specialization by ensuring that only relevant functions are transformed, enhancing performance and correctness.
@LeiWang1999 LeiWang1999 merged commit f617c58 into tile-ai:main May 3, 2025
3 checks passed
@LeiWang1999 LeiWang1999 deleted the feat_20250503_specialize branch May 3, 2025 10:56
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 18, 2025
… pass (tile-ai#447)

* [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.

* [Refactor] Rename operations for consistency in lower_hopper_intrin and related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.

* [Refactor] Rename operations to snake_case for consistency

* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.

* [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.

* [Refactor] Clean up code formatting and improve readability

* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.

* lint fix

* [Refactor] Update mbarrier functions for improved clarity and consistency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.

* Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` for improved code readability and separation of logical sections.

* [Feature] Add examples for warp specialization and TMA barrier integration

* Introduced three new example scripts: `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, and `example_warp_specialize_mla.py` demonstrating matrix multiplication with warp specialization and TMA barriers.
* Implemented kernel functions with shared memory allocation and memory barrier synchronization for improved performance.
* Enhanced the TileLang API with new methods for compiling and testing kernels in Python using PyTorch.
* Updated the `phase.py` to include TMA barrier injection in the optimization process.
* Improved documentation and comments for better clarity on usage and functionality.

* [Feature] Add example for warp specialization in GEMM with TMA barriers

* Introduced a new example script `example_warp_specialize_gemm_stage2.py` demonstrating matrix multiplication using warp specialization and TMA barriers.
* Implemented a kernel function with shared memory allocation and memory barrier synchronization for enhanced performance.
* Included functionality to compile the kernel into a PyTorch-compatible function and validate its correctness against PyTorch's reference implementation.
* Enhanced documentation and comments for clarity on usage and functionality.

* lint fix

* [Feature] Implement WarpSpecializedDetector for TMA and MBarrier Detection

* Added the `WarpSpecializedDetector` class to identify the presence of TMA operations and memory barrier operations within a given TIR statement.
* Enhanced the `WarpSpecialized` pass to utilize the detector, allowing for conditional substitution based on the detection results.
* Improved code organization by including necessary headers and utilizing the `IRVisitorWithAnalyzer` for analysis.
* This addition aims to optimize warp specialization by ensuring that only relevant functions are transformed, enhancing performance and correctness.

* lint fix
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 20, 2025
… pass (tile-ai#447)

* [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.

* [Refactor] Rename operations for consistency in lower_hopper_intrin and related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.

* [Refactor] Rename operations to snake_case for consistency

* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.

* [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.

* [Refactor] Clean up code formatting and improve readability

* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.

* lint fix

* [Refactor] Update mbarrier functions for improved clarity and consistency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.

* Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` for improved code readability and separation of logical sections.

* [Feature] Add examples for warp specialization and TMA barrier integration

* Introduced three new example scripts: `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, and `example_warp_specialize_mla.py` demonstrating matrix multiplication with warp specialization and TMA barriers.
* Implemented kernel functions with shared memory allocation and memory barrier synchronization for improved performance.
* Enhanced the TileLang API with new methods for compiling and testing kernels in Python using PyTorch.
* Updated the `phase.py` to include TMA barrier injection in the optimization process.
* Improved documentation and comments for better clarity on usage and functionality.

* [Feature] Add example for warp specialization in GEMM with TMA barriers

* Introduced a new example script `example_warp_specialize_gemm_stage2.py` demonstrating matrix multiplication using warp specialization and TMA barriers.
* Implemented a kernel function with shared memory allocation and memory barrier synchronization for enhanced performance.
* Included functionality to compile the kernel into a PyTorch-compatible function and validate its correctness against PyTorch's reference implementation.
* Enhanced documentation and comments for clarity on usage and functionality.

* lint fix

* [Feature] Implement WarpSpecializedDetector for TMA and MBarrier Detection

* Added the `WarpSpecializedDetector` class to identify the presence of TMA operations and memory barrier operations within a given TIR statement.
* Enhanced the `WarpSpecialized` pass to utilize the detector, allowing for conditional substitution based on the detection results.
* Improved code organization by including necessary headers and utilizing the `IRVisitorWithAnalyzer` for analysis.
* This addition aims to optimize warp specialization by ensuring that only relevant functions are transformed, enhancing performance and correctness.

* lint fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant