add paged-attention #20

Merged
merged 5 commits into FlagOpen:main on Mar 7, 2024

Conversation

kaiyuanm (Contributor):

  1. Add the paged attention operation commonly used in large language model inference. For background on paged attention, see https://arxiv.org/pdf/2309.06180.pdf (a short sketch of the block-table idea appears after the notes below).
  2. This implementation requires Triton 2.2.0.

Here is some performance data comparing this Triton implementation with vLLM.

vllm_paged_attention-B32-G8-D128-bs16-v2:
   context_len     triton  vllm-0.3.0
0        512.0   6.618643    3.582180
1       1024.0   7.798634    3.925020
2       2048.0   8.560921    4.286430
3       4096.0  12.241113    4.460762
4       8192.0  13.589705    4.558541
5      16384.0  14.094599    4.609148

Note:

  • The input layouts differ, so the performance numbers are for reference only.
  • Some input shapes are still being optimized.
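
For readers new to the idea, here is a minimal, hedged Python sketch of the block-table indirection that paged attention relies on. The names and sizes are illustrative only, not this PR's code: the KV cache lives in fixed-size physical blocks, and a per-sequence block table maps logical token positions to those blocks.

import numpy as np

NUM_PHYS_BLOCKS, KV_BLOCK_SIZE, HEAD_SIZE = 16, 4, 8

# Physical KV cache: fixed-size blocks that need not be contiguous per sequence.
k_cache = np.random.randn(NUM_PHYS_BLOCKS, KV_BLOCK_SIZE, HEAD_SIZE)

# Block table for one sequence: logical block index -> physical block index.
block_table = np.array([7, 2, 11, 5])

def lookup_key(token_pos):
    # Resolve a token's cached key vector through the block table.
    logical_block = token_pos // KV_BLOCK_SIZE
    block_offset = token_pos % KV_BLOCK_SIZE
    physical_block = block_table[logical_block]
    return k_cache[physical_block, block_offset]

print(lookup_key(9).shape)  # (HEAD_SIZE,)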

elif num_splits > 1:
    partition_size = triton.cdiv(max_context_len, num_splits)
    partition_size = triton.next_power_of_2(partition_size)
    assert partition_size >= kv_block_size
Collaborator:

else assert num_splits == 1

Contributor Author:

This follows the flash_attn flash_attn_with_kvcache definition:

num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.

Collaborator:

Yes, but here you still need to guard num_splits against negative values.
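
To make the suggestion concrete, here is a minimal sketch of how the dispatch could be guarded, assuming the flash_attn-style semantics quoted above. num_splits, max_context_len, kv_block_size, and partition_size come from the snippet under review; the explicit 0/1 branches and the assertion message are illustrative, not the merged code.

assert num_splits >= 0, f"num_splits must be non-negative, got {num_splits}"
if num_splits == 0:
    # Heuristic path: pick the number of splits automatically (details elided).
    ...
elif num_splits == 1:
    # No split: the whole key/value sequence is handled as a single partition.
    ...
else:  # num_splits > 1
    partition_size = triton.cdiv(max_context_len, num_splits)
    partition_size = triton.next_power_of_2(partition_size)
    assert partition_size >= kv_block_size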

kv_mask = mask_offset[:, None] < context_len

# k: [KV_BLOCK_SIZE, HEAD_SIZE]
k = tl.load(k_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0)
Collaborator:

Is it possible to omit kv_mask here and mask the output instead?

Contributor Author:

Looks OK to me. What do you think? @iclementine

Collaborator:

When loading, the mask is obligatory if the access would otherwise touch illegal memory. However, you can use the modulo trick to avoid masking; whether it is more efficient than masking depends on the relative cost.
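
For illustration only, a hedged toy kernel contrasting the two options. The kernel name, shapes, and launch are assumptions rather than this PR's kernel; only the tl.load usage mirrors the snippet under discussion.

import torch
import triton
import triton.language as tl

@triton.jit
def load_k_block(k_ptr, out_ptr, context_len,
                 KV_BLOCK_SIZE: tl.constexpr, HEAD_SIZE: tl.constexpr,
                 USE_MODULO: tl.constexpr):
    rows = tl.arange(0, KV_BLOCK_SIZE)
    cols = tl.arange(0, HEAD_SIZE)
    offsets = rows[:, None] * HEAD_SIZE + cols[None, :]
    if USE_MODULO:
        # Modulo trick: wrap out-of-range rows back into the valid region so every
        # address is legal; the duplicated rows must be neutralized later (e.g. when
        # the attention scores are masked), but no load-time mask is needed.
        safe_offsets = (rows % context_len)[:, None] * HEAD_SIZE + cols[None, :]
        k = tl.load(k_ptr + safe_offsets)
    else:
        # Masked load: obligatory whenever the raw offsets could touch illegal memory.
        kv_mask = rows[:, None] < context_len
        k = tl.load(k_ptr + offsets, mask=kv_mask, other=0.0)
    tl.store(out_ptr + offsets, k)

# Tiny launch to exercise the kernel (assumes a CUDA device is available).
k = torch.randn(16, 8, device="cuda")
out = torch.empty_like(k)
load_k_block[(1,)](k, out, 5, KV_BLOCK_SIZE=16, HEAD_SIZE=8, USE_MODULO=True)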

@iclementine (Collaborator) left a comment:

lgtm

@iclementine merged commit b0045fb into FlagOpen:main on Mar 7, 2024. 1 check passed.