1 change: 1 addition & 0 deletions MANIFEST.in
@@ -3,6 +3,7 @@ include CMakeLists.txt
include requirements.txt
include requirements-test.txt
include requirements-dev.txt
include tilelang/jit/adapter/cython/cython_wrapper.pyx
recursive-include src *
recursive-include 3rdparty *
recursive-exclude 3rdparty/clang* *
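A quick way to sanity-check the new MANIFEST.in entry is to build an sdist and list its contents. The sketch below is a minimal check, assuming an sdist has already been built into dist/ (for example with python -m build --sdist); the dist/*.tar.gz glob is an assumption about the local build layout, not something this PR specifies.

import glob
import tarfile

# Assumes an sdist was already built into dist/, e.g. via `python -m build --sdist`.
sdist_path = sorted(glob.glob("dist/*.tar.gz"))[-1]
with tarfile.open(sdist_path, "r:gz") as sdist:
    members = sdist.getnames()

# The path added to MANIFEST.in should now show up inside the archive.
assert any(name.endswith("tilelang/jit/adapter/cython/cython_wrapper.pyx") for name in members), \
    "cython_wrapper.pyx is missing from the sdist"
print("cython_wrapper.pyx packaged in", sdist_path)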
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.1.0
0.1.1
36 changes: 21 additions & 15 deletions examples/blocksparse_attention/block_sparse_attn_tilelang.py
@@ -7,24 +7,28 @@
import tilelang.language as T
import torch.nn.functional as F


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :,-2:,:] = True
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
return dense_mask


def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
dense_mask = x > threshold
dense_mask = x > threshold
if use_dense_for_last_block:
dense_mask[:, :,-2:,:] = True
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
return dense_mask


def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
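For readers skimming the diff, here is a minimal sketch of what the reformatted get_sparse_attn_mask_from_topk computes, on toy shapes chosen purely for illustration (the example itself operates on [batch, heads, downsample_len, downsample_len] block scores): each query block keeps its top-k key blocks, and the in-place tril_ then drops blocks above the causal diagonal.

import torch

# Toy block-score grid: 1 batch, 1 head, a 4x4 grid of block-level scores.
x = torch.randn(1, 1, 4, 4)
topk = 2

# Same steps as get_sparse_attn_mask_from_topk above, without the dense-last-block option.
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full(x.shape, False, dtype=torch.bool)
dense_mask.scatter_(-1, sparse_index, True)   # keep the top-k key blocks per query block
dense_mask.tril_()                            # causal: drop blocks above the diagonal
print(dense_mask[0, 0].int())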
@@ -136,7 +140,7 @@ def main(
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)

T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
@@ -165,6 +169,7 @@ def main(

return kernel_func(block_M, block_N, num_stages, threads)


def test_topk_sparse_attention():
# Config
BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64
@@ -177,13 +182,15 @@ def test_topk_sparse_attention():
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)

sm_scale = 1.0 / (D_HEAD ** 0.5)
sm_scale = 1.0 / (D_HEAD**0.5)

# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
x_ds[:,:,:,0] = 100
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

# Run Triton kernel
@@ -194,25 +201,24 @@ def test_topk_sparse_attention():

# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(),
torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal

# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)

print("ref_output", ref_output)
print("tilelang_output", tilelang_output)


# Verify accuracy
assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \
"TileLang output doesn't match reference"
print("Pass topk sparse attention test with qlen == klen")


if __name__ == "__main__":
test_topk_sparse_attention()
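The reference check in test_topk_sparse_attention expands the block-level mask to token resolution with torch.kron and then intersects it with a causal lower-triangular mask before the softmax. The snippet below is a small sketch of that expansion on toy sizes (BLOCK=2 and a 3x3 block mask are made up for illustration; the test uses BLOCK=64 and per-batch, per-head masks).

import torch

# Toy sizes for illustration only; the test uses SEQ_LEN=256 and BLOCK=64.
BLOCK = 2
block_mask = torch.tensor([[True, False, False],
                           [True, True,  False],
                           [False, True,  True]])   # 3x3 block-level mask

# torch.kron repeats each block-level decision over a BLOCK x BLOCK tile,
# turning the 3x3 block mask into a 6x6 token-level mask.
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK)).bool()

# The causal constraint is applied afterwards at token granularity,
# mirroring the reference path in the test above.
full_mask = full_mask & torch.tril(torch.ones_like(full_mask))
print(full_mask.int())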