add paged-attention #20
src/flag_attn/paged.py
Outdated
elif num_splits > 1:
    partition_size = triton.cdiv(max_context_len, num_splits)
    partition_size = triton.next_power_of_2(partition_size)
    assert partition_size >= kv_block_size
Add an else branch that asserts num_splits == 1.
Refer to the definition of flash_attn_with_kvcache in flash_attn:
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Yes, here you need to guard num_splits against negatives
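A minimal sketch of what such a guard could look like, assuming the num_splits semantics quoted above (0 = heuristic, 1 = no split, > 1 = explicit split count). The helper names `validate_num_splits` and `partition_size_for_splits` are hypothetical and not part of this PR; `cdiv` and `next_power_of_2` are pure-Python stand-ins for `triton.cdiv` and `triton.next_power_of_2` so the snippet runs without Triton.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division (stand-in for triton.cdiv)."""
    return (a + b - 1) // b


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (stand-in for triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def validate_num_splits(num_splits: int) -> None:
    """Hypothetical guard: reject negative split counts up front."""
    if num_splits < 0:
        raise ValueError(f"num_splits must be >= 0, got {num_splits}")


def partition_size_for_splits(max_context_len: int, num_splits: int,
                              kv_block_size: int) -> int:
    """Mirror of the `num_splits > 1` branch in the hunk above."""
    validate_num_splits(num_splits)
    if num_splits <= 1:
        raise ValueError("this sketch only covers the explicit-split case (num_splits > 1)")
    partition_size = next_power_of_2(cdiv(max_context_len, num_splits))
    if partition_size < kv_block_size:
        raise ValueError("partition_size must be at least kv_block_size")
    return partition_size
```

With the guard in place, a negative value fails early with a clear message instead of silently falling through the `if`/`elif` chain.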
kv_mask = mask_offset[:, None] < context_len

# k: [KV_BLOCK_SIZE, HEAD_SIZE]
k = tl.load(k_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0)
Is it possible to omit kv_mask here and then mask the output?
Looks OK to me. What do you think? @iclementine
When loading, the mask is obligatory if the load would otherwise access illegal memory. However, you can use the modulo trick to avoid masking. Whether it is more efficient than masking depends on the cost.
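To make the two options concrete, here is a rough pure-Python emulation (not Triton code, and not taken from this PR). `load_masked` imitates a masked load with `other=0.0`, `load_wrapped` imitates the modulo trick by wrapping every offset into the valid range, and `mask_scores` masks the output afterwards. All three helper names are hypothetical, and the sketch assumes the valid length is greater than zero.

```python
NEG_INF = float("-inf")


def load_masked(buf, offsets, valid_len, other=0.0):
    """Masked load: lanes with offset >= valid_len never touch memory."""
    return [buf[o] if o < valid_len else other for o in offsets]


def load_wrapped(buf, offsets, valid_len):
    """Modulo trick: wrap offsets into [0, valid_len) so every access is legal."""
    return [buf[o % valid_len] for o in offsets]


def mask_scores(scores, offsets, valid_len):
    """Mask the output instead: padded lanes become -inf so softmax ignores them."""
    return [s if o < valid_len else NEG_INF for s, o in zip(scores, offsets)]


keys = [0.5, -1.0, 2.0, 4.0, 3.0]  # only 5 valid entries
offsets = list(range(8))           # a block of 8 lanes, 3 of them out of range
q = 2.0

scores_masked = mask_scores(
    [q * k for k in load_masked(keys, offsets, len(keys))], offsets, len(keys))
scores_wrapped = mask_scores(
    [q * k for k in load_wrapped(keys, offsets, len(keys))], offsets, len(keys))
```

Both paths produce the same scores once the output is masked; the wrapped load reads some duplicate valid entries where the masked load fills in zeros, which is the cost trade-off mentioned above.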
lgtm
Here is some performance data comparing Triton and vLLM.
Note: