
Conversation

@LeiWang1999 (Member)

This pull request makes significant changes to examples/flash_attention/example_mha_fwd_bhsd.py and introduces a new file, examples/seer_attention/block_sparse_attn_tilelang.py. The changes improve the flash attention example and add support for block sparse attention.

Improvements to flash attention:

  • Modified the flashattn function to accept separate query and key/value sequence lengths (seq_q and seq_kv) instead of a single sequence length (seq_len); a shape-level sketch follows this list.
  • Updated buffer shapes and kernel functions to accommodate the separate seq_q and seq_kv lengths.
  • Adjusted the reference program and argument parsing to use the new sequence length parameters.
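To make the shape implications concrete, here is a plain PyTorch-style reference for attention with distinct seq_q and seq_kv in BHSD layout. The function and tensor names are illustrative assumptions for this summary, not the PR's actual kernel or reference code.

```python
import torch
import torch.nn.functional as F

def ref_attention(q, k, v, is_causal=False):
    # q: [batch, heads, seq_q, dim]; k, v: [batch, heads, seq_kv, dim] (BHSD layout).
    # seq_q and seq_kv may differ, e.g. decoding a few queries against a long KV cache.
    dim = q.shape[-1]
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / (dim ** 0.5)  # [b, h, seq_q, seq_kv]
    if is_causal:
        seq_q, seq_kv = q.shape[2], k.shape[2]
        # Align the causal diagonal so the last query attends to the last key.
        mask = torch.tril(
            torch.ones(seq_q, seq_kv, dtype=torch.bool, device=q.device),
            diagonal=seq_kv - seq_q,
        )
        scores = scores.masked_fill(~mask, float("-inf"))
    probs = F.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bhkd->bhqd", probs, v)  # [b, h, seq_q, dim]
```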

Introduction of block sparse attention:

  • Added a new file examples/seer_attention/block_sparse_attn_tilelang.py implementing block sparse attention using TileLang.
  • Defined functions to create sparse attention masks based on top-k values and thresholds (a hedged sketch follows this list).
  • Implemented the blocksparse_flashattn function to perform block sparse attention, including the necessary kernel functions and macros.
  • Included test functions to verify the correctness of the block sparse attention implementation.
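As a rough illustration of how such block-level masks can be derived (the helper names and the pooled-score input are assumptions, not the file's actual implementation), the sketch below builds a boolean mask over (query block, KV block) pairs using either a top-k or a threshold rule:

```python
import torch

def topk_block_mask(block_scores, topk):
    # block_scores: [batch, heads, num_q_blocks, num_kv_blocks] pooled importance scores.
    idx = block_scores.topk(topk, dim=-1).indices
    mask = torch.zeros_like(block_scores, dtype=torch.bool)
    mask.scatter_(-1, idx, True)
    return mask  # True = keep this KV block for this query block

def threshold_block_mask(block_scores, threshold):
    # Keep every KV block whose softmax-normalized score exceeds the threshold.
    probs = torch.softmax(block_scores, dim=-1)
    return probs > threshold
```

Either mask can then be handed to a block sparse attention kernel, which skips the KV blocks marked False.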

…end Support

- Update `example_mha_fwd_varlen.py` to use Cython backend for kernel compilation
- Remove unused imports and simplify function signature
- Modify `flashattn` function to handle max sequence length as a separate argument (see the varlen sketch after this list)
- Update kernel call to include max sequence length parameter
- Improve code readability and remove commented-out code
- Add print statement to confirm successful assertion
- Improve line breaks and code formatting in `lower.py`, `wrapper.py`, and `tensor.py`
- Simplify line breaks and reduce unnecessary whitespace
- Enhance code readability by adjusting indentation and line breaks
- Update example MHA forward pass script with cleaner tensor initialization
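To make the "max sequence length as a separate argument" point concrete, here is a hedged PyTorch-style reference for variable-length attention over packed sequences. The cu_seqlens/max_seqlen convention mirrors common varlen FlashAttention interfaces; the names are assumptions about this example, not quotes from it.

```python
import torch
import torch.nn.functional as F

def ref_varlen_attention(q, k, v, cu_seqlens, max_seqlen):
    # q, k, v: [total_tokens, heads, dim], with sequences packed back to back.
    # cu_seqlens: [batch + 1] prefix sums of per-sequence lengths.
    # max_seqlen is passed explicitly so a kernel can size its grid and buffers up front.
    out = torch.empty_like(q)
    for i in range(cu_seqlens.numel() - 1):
        s, e = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        assert e - s <= max_seqlen
        qi = q[s:e].transpose(0, 1)  # [heads, len_i, dim]
        ki = k[s:e].transpose(0, 1)
        vi = v[s:e].transpose(0, 1)
        scores = qi @ ki.transpose(-1, -2) / (qi.shape[-1] ** 0.5)
        out[s:e] = (F.softmax(scores, dim=-1) @ vi).transpose(0, 1)
    return out
```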
…nd macro generator

- Modify import statements in test_tilelang_kernel_dequantize_gemm.py
- Replace bitblas imports with tilelang.intrinsics imports for MMA-related utilities
- Update main function to use tilelang.testing.main()
- Implement block sparse attention kernels for both TileLang and Triton (a block-level sketch follows this list)
- Add utility functions for generating sparse attention masks using top-k and threshold methods
- Support causal and variable-length attention scenarios
- Include test cases for different sequence length configurations
- Demonstrate block-level sparse attention with configurable parameters
- Improve code formatting in block_sparse_attn_tilelang.py and block_sparse_attn_triton.py
- Enhance readability by adjusting line breaks and indentation
- Simplify kernel and function calls with better formatting
- Add whitespace and line break improvements for better code clarity
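The sketch below shows, in plain PyTorch, the block-level control flow such a kernel follows: for each query block, only the KV blocks selected by the mask are visited and combined via an online softmax. It is a simplified single-head model for illustration, not the TileLang or Triton kernel from this PR, and it assumes each query block keeps at least one KV block (e.g. its diagonal).

```python
import torch

def block_sparse_attention(q, k, v, block_mask, block_size):
    # q, k, v: [seq, dim] (single head for clarity); block_mask: [num_q_blocks, num_kv_blocks] bool.
    seq, dim = q.shape
    out = torch.zeros_like(q)
    scale = dim ** -0.5
    for qb in range(block_mask.shape[0]):
        qs = qb * block_size
        q_tile = q[qs:qs + block_size]
        m = torch.full((q_tile.shape[0], 1), float("-inf"))   # running row max
        l = torch.zeros(q_tile.shape[0], 1)                   # running softmax denominator
        acc = torch.zeros(q_tile.shape[0], dim)               # running weighted sum of V
        for kb in range(block_mask.shape[1]):
            if not block_mask[qb, kb]:
                continue  # skip non-selected KV blocks entirely
            ks = kb * block_size
            s = (q_tile @ k[ks:ks + block_size].T) * scale
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)
            alpha = torch.exp(m - m_new)  # rescale previous partial results
            l = l * alpha + p.sum(dim=-1, keepdim=True)
            acc = acc * alpha + p @ v[ks:ks + block_size]
            m = m_new
        out[qs:qs + block_size] = acc / l
    return out
```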
LeiWang1999 merged commit 38d13c2 into tile-ai:main on Mar 2, 2025
3 checks passed
