
Conversation

@chengyupku
Contributor

This pull request includes significant updates to the examples/deepseek_mla project, focusing on enhancing the documentation and optimizing the example_mla_decode.py script. The most important changes include the addition of a detailed README for MLA, the introduction of an argument parser for better script configurability, and various performance optimizations in the kernel implementation.

Documentation improvements:

  • examples/deepseek_mla/README.md: Added a comprehensive guide on writing high-performance kernels with TileLang, using MLA as an example. The guide covers MLA introduction, benchmark results, implementation details, and various optimization techniques like threadblock swizzling, shared memory swizzling, warp-specialization, and pipelining.
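
To make the swizzling idea concrete, below is a minimal, library-free Python sketch of threadblock swizzling as the README describes it. The grid shape and panel size are illustrative parameters, not values taken from the MLA kernel.

```python
# Minimal, library-free sketch of threadblock swizzling.
# Blocks are regrouped into "panels" of panel_size rows so that
# consecutively scheduled blocks reuse the same tile columns (or the
# same KV tiles in attention), improving L2 cache hit rate.
def swizzle_block_index(linear_bid: int, grid_m: int, grid_n: int, panel_size: int):
    panel_blocks = panel_size * grid_n          # blocks per panel
    panel_id = linear_bid // panel_blocks
    within = linear_bid % panel_blocks
    rows_in_panel = min(panel_size, grid_m - panel_id * panel_size)
    # Column-major order inside the panel: walk down a short column of
    # blocks before moving to the next column.
    bid_m = panel_id * panel_size + within % rows_in_panel
    bid_n = within // rows_in_panel
    return bid_m, bid_n

if __name__ == "__main__":
    grid_m, grid_n, panel = 8, 8, 4
    order = [swizzle_block_index(i, grid_m, grid_n, panel) for i in range(grid_m * grid_n)]
    print(order[:8])  # the first eight blocks all touch a small set of tile columns
```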

Code enhancements:

  • examples/deepseek_mla/example_mla_decode.py: Added a command-line argument parser for flexible configuration and applied kernel-level performance optimizations such as threadblock swizzling, shared-memory swizzling, warp specialization, and pipelining.

These changes aim to improve the performance and usability of the MLA decoding example, making it easier to configure and more efficient to run.
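
The argument parser itself is not reproduced in this description; a minimal sketch of what such a CLI for example_mla_decode.py could look like is below. The flag names (--batch, --heads, --kv_ctx, --dim) and their defaults are assumptions for illustration, not the script's actual interface.

```python
# Hypothetical CLI sketch for example_mla_decode.py; the real script may
# use different flag names and defaults.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="MLA decode example")
    parser.add_argument("--batch", type=int, default=64, help="batch size")                 # assumed flag
    parser.add_argument("--heads", type=int, default=128, help="number of query heads")     # assumed flag
    parser.add_argument("--kv_ctx", type=int, default=8192, help="KV cache length")         # assumed flag
    parser.add_argument("--dim", type=int, default=512, help="latent (compressed KV) dim")  # assumed flag
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    print(f"batch={args.batch}, heads={args.heads}, kv_ctx={args.kv_ctx}, dim={args.dim}")
```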

…HA WGMMA pipelined example (FA3-like scheduling)

This commit introduces a new transformation pass `RewriteWgmmaSync` to optimize warp-group matrix multiply-accumulate (WGMMA) operations in the TileLang compiler:

- Implemented `WgmmaSyncRewriter` in `src/transform/wgmma_sync_rewriter.cc`
- Added pass registration for `RewriteWgmmaSync`
- Updated `tilelang/engine/phase.py` to include the new transformation pass
- Updated `tilelang/transform/__init__.py` to expose the new pass

The rewriter tracks dependencies between WGMMA operations and places the required synchronization, improving pipeline efficiency for complex matrix multiplication kernels.
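
As a rough sketch of how the pass could be invoked from Python, assuming it is exposed as `tilelang.transform.RewriteWgmmaSync()` (per the `__init__.py` change) and chained with TVM's standard pass infrastructure; the surrounding passes shown here are placeholders, not the actual contents of `tilelang/engine/phase.py`.

```python
# Illustrative only: how a lowering pipeline could include the new pass.
# The neighbouring passes and their ordering are assumptions.
import tvm
import tilelang.transform as tl_transform

def optimize_for_hopper(mod: tvm.IRModule) -> tvm.IRModule:
    pipeline = tvm.transform.Sequential([
        # ... earlier TileLang/TIR lowering passes ...
        tl_transform.RewriteWgmmaSync(),  # reorder/insert sync for WGMMA pipelines
        # ... later target-specific passes ...
    ])
    return pipeline(mod)
```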
Improve thread tag validation in warp specialized rewriter to prevent unintended transformations:
- Add more precise checks for threadIdx.y and threadIdx.z
- Validate thread extent to ensure only single-extent thread bindings are allowed
- Prevent warp specialization for multi-extent thread bindings in y and z dimensions
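
The check itself lives in the C++ warp-specialized rewriter; as a hedged Python approximation of the same condition (not the actual implementation), one could scan the TIR for `thread_extent` bindings of threadIdx.y / threadIdx.z:

```python
# Python sketch of the validation logic; the real check is in C++.
import tvm
from tvm import tir

def can_warp_specialize(func: tir.PrimFunc) -> bool:
    """Return False if threadIdx.y or threadIdx.z is bound with extent > 1."""
    ok = True

    def visit(node):
        nonlocal ok
        if isinstance(node, tir.AttrStmt) and node.attr_key == "thread_extent":
            iv = node.node  # the bound IterVar
            if iv.thread_tag in ("threadIdx.y", "threadIdx.z"):
                extent = node.value
                if not (isinstance(extent, tir.IntImm) and extent.value == 1):
                    ok = False

    tir.stmt_functor.post_order_visit(func.body, visit)
    return ok
```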
…lash Attention Implementations

- Add new `flash_attn` macro for non-split flash attention implementation
- Add swizzled layout for tile in shared memory
- Use threadblock swizzle to improve L2 cache hit rate
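
The shared-memory swizzle is a standard bank-conflict-avoidance trick; the index-arithmetic sketch below illustrates the general idea and is not TileLang's actual layout primitive (the tile stride and group size are assumed values).

```python
# Minimal sketch of an XOR-based shared-memory swizzle.
# Permuting the column index by a function of the row spreads a warp's
# accesses across banks, avoiding bank conflicts when the tile is read
# column-wise (e.g. as one operand of a GEMM).
def swizzled_offset(row: int, col: int, row_stride: int = 64, group: int = 8) -> int:
    """Map a logical (row, col) element of a shared-memory tile to a
    physical offset. `group` elements stay contiguous (one vectorized
    load), and groups are XOR-permuted per row."""
    group_id = col // group
    swizzled_group = group_id ^ (row % (row_stride // group))
    return row * row_stride + swizzled_group * group + col % group
```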
…nce Benchmarks

- Add detailed README.md explaining MLA (Multi-Head Latent Attention) implementation
- Include performance benchmark images for batch sizes 64 and 128
- Add layout visualization images for QK and PV operations
- Implement torch reference implementations in torch_refs.py
- Update example_mla_decode.py with command-line argument support and flexible configuration
- Add performance benchmarking and comparison with other implementations
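
The contents of torch_refs.py are not reproduced in this description; for flavour, a generic single-step decode attention reference in PyTorch might look roughly like the sketch below, which deliberately ignores MLA-specific details such as the compressed latent KV.

```python
# Generic torch reference for one decode step of multi-head attention.
# This is an illustrative stand-in, not the contents of torch_refs.py.
import torch

def ref_attention_decode(q, k_cache, v_cache):
    """q: [batch, heads, dim]; k_cache/v_cache: [batch, kv_len, heads, dim]."""
    dim = q.shape[-1]
    scale = dim ** -0.5
    # scores: [batch, heads, kv_len]
    scores = torch.einsum("bhd,bshd->bhs", q.float(), k_cache.float()) * scale
    probs = torch.softmax(scores, dim=-1)
    # output: [batch, heads, dim]
    out = torch.einsum("bhs,bshd->bhd", probs, v_cache.float())
    return out.to(q.dtype)
```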
@chengyupku requested a review from LeiWang1999 on March 3, 2025 at 17:20
Member

@LeiWang1999 left a comment


Overall LGTM, Merged :)

@chengyupku merged commit 0bbd063 into tile-ai:main on Mar 3, 2025
3 checks passed