[Kernel][RFC] Refactor the punica kernel based on Triton #5036
Conversation
@jeejeelee can you coordinate with @FurtherAI who is working on something similar? #5025
@Yard1 How can I coordinate with @FurtherAI? I'm happy to work together to push this feature forward, but it seems like we have different approaches to kernel implementation. Which one should we choose?
@jeejeelee Would it be possible for you two to talk together and pick the best approach? I can connect you two, or you can use the comments here.
We can battle it out. We can take the best of both, or just select one if they are close. I do have a good portion of the integration with vLLM written, but untested, so it would be good if that still applied in the end. @jeejeelee First, can you tell me about any limitations or assumptions? LoRA weight format (paged or padded), do the kernels operate on each LoRA at its rank or at a max rank, do you need a minimum group size per LoRA (tl.dot requires shapes of 16 or greater, so do you pad if there aren't enough tokens for a LoRA?), do you have any benchmarking, do you have any tests for correctness, or anything else you think is important?
@FurtherAI Cheers to wonderful similarities; I'm glad to see like-minded people addressing these issues. Here is my response:
I really enjoyed our conversation, which deepened my understanding of multi-LoRA. Thanks. Cheers
@jeejeelee thanks for the detailed response. I think you understand paged format correctly. So my kernels run on a paged lora format, like in S-LoRA. This of course allows for more active loras at a time. They also run at the rank of each lora which provides some solid speed ups, especially at higher sequence length. And they're compatible with any rank, down to 1 without padding, which means they can run fast with arbitrary grouping. Can you give me some sample code to run your kernels so I can benchmark them? If yours has a significant speed advantage over mine, then we should make changes, but if not, mine are significantly more developed and there would be no reason to duplicate the effort.
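A very rough sketch of the two weight layouts being contrasted above (padded at a max rank versus an S-LoRA-style paged pool). The shapes and the page-table structure here are made up for illustration and are not either author's actual data structure:

```python
import torch

max_loras, max_rank, hidden = 8, 64, 4096  # made-up sizes

# Padded layout: every adapter stored at max_rank in one dense tensor,
# indexed directly by lora id (wastes memory for low-rank adapters).
padded_a = torch.zeros(max_loras, hidden, max_rank)

# Paged layout (S-LoRA style): rank components from all adapters live in a
# shared pool, and a per-adapter index list records which rows belong to it,
# so each adapter occupies exactly `rank` rows and is computed at its true rank.
pool = torch.randn(1024, hidden)                   # unified pool of rank rows
page_table = {0: torch.tensor([3, 17, 42, 99])}    # adapter 0 has rank 4
lora0_a = pool[page_table[0]].t()                  # gather -> (hidden, 4)
```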
@FurtherAI I apologize for the late reply due to the weekend. I've provided some code in
No worries, thanks for the code. I'll try to get back to you soon.
This is great! The punica kernels also take a very long time to build and take up space in our wheels. Moving them to Triton can make the building process faster and reduce the wheel size.
Currently this implementation still has two kernels dealing with shrink and expand separately. I wonder whether we could merge them into one, so that Triton could do pipeline autotuning for the fused kernel? And taking it one step further, could we have a kernel that fuses x@W + x@A@B, so that the weight-loading cost at the decoding stage is reduced?
Hi, thanks for your attention and suggestions. Of course, we can fuse these two back-to-back GEMM kernels, as done in b2b_grouped_gemm. Actually, implementing this in a single kernel was my first idea, but I have not yet figured out how to map
Regarding "pipeline autotune for this fused kernel," I'm not quite sure what this means; did you mean to refer to the
As for fusing
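To make the fusion being discussed concrete, here is a small PyTorch reference of the math involved: today the shrink and expand ops run back to back through an intermediate buffer, and a fused kernel would compute the same x@W + (x@A)@B in one pass. Function names and shapes are illustrative only, not this PR's API:

```python
import torch

# Illustrative shapes: x (T, K), w (N, K), a (K, r), b (r, N)
def two_pass(x, w, a, b):
    base = x @ w.t()           # dense base projection
    shrunk = x @ a             # shrink kernel writes this (T, r) buffer to memory
    return base + shrunk @ b   # expand kernel reads it back and adds

def fused(x, w, a, b):
    # A fused kernel would compute the same result in one launch, keeping the
    # (T, r) intermediate on-chip instead of round-tripping through memory.
    return x @ w.t() + (x @ a) @ b

x, w = torch.randn(8, 4096), torch.randn(4096, 4096)
a, b = torch.randn(4096, 16), torch.randn(16, 4096)
assert torch.allclose(two_pass(x, w, a, b), fused(x, w, a, b), atol=1e-2)
```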
@jeejeelee So it looks like our kernels accomplish two partly different goals. Yours can function as a drop-in replacement for the current Punica kernels. I have done some benchmarking (which I'll attach), and assuming I understand how to bench yours correctly and it is doing what I expect, I think yours have pretty comparable speed to Punica across the board. Mine has two major differences: loras in a paged memory format and computing at the rank of each lora. Together these allow for more loras at a time, and assuming loras are randomly distributed in rank, they get pretty fast, which makes up for the cost of loading the extra indexing information needed for the paged format. Your strategy of looping along the rank dimension didn't give any speedup to mine, so I expect any differences come from loading the indexing and the different format. Based on the benchmarks, and assuming lora ranks are distributed, these kernels don't offer a significant speed advantage over mine. Though they would probably be faster than mine if you add the ability to compute at the actual rank (lacking the paged loras, though). It would be reasonable to use yours as a drop-in replacement for Punica until I push mine, or to keep both and use mine if the user wants more active loras. Or it is reasonable to just wait for the kernels I've developed to be fully integrated, since these kernels will likely be a marginal improvement. Up to the vLLM team what they want to do.
Thanks for your benchmark work. It seems that you mainly want to support S-LoRA, which is why you implemented these kernels. Currently, I am following the above work plan and attempting to replace bgmv with my operator (progress: 80%). Additionally, I have also added bgmv-related Triton kernels (these kernels have not been fully tested yet). Regardless, I will complete my work plan. Cheers
Currently, while the sgmv I've implemented can achieve high performance in long-sequence scenarios, it falls short compared to Punica's bgmv in cases involving small batches and short sequences. I'm working on improving its performance in these scenarios. Additionally, I've implemented bgmv for use in the decoding stage, but it's still in its preliminary stages of development. I've already implemented some code to replace bgmv with the Triton sgmv, mainly by adopting the previous shallow-copy approach to pass the parameters required by the sgmv. However, I'm not sure if this approach is reasonable; could I trouble you to provide guidance and suggestions? @Yard1
@K-Mistele Thank you for your testing, I really appreciate it. By the way, the vLLM team has released version 0.5.4, and this PR was integrated into it. You can run
@jeejeelee the latest vLLM version (0.5.4) does not support lora_rank > 64
#7146 has addressed this issue. However, you'll have to wait until version
Got it. Thanks for your work. @jeejeelee
Question 1: Question 2:
These were previous issues with Punica. This pull request addresses some of them. Please carefully read the initial description again.
It's unclear what error you encountered. You should also consider submitting an issue. |
Thank you for the excellent work! @jeejeelee I'm checking the Triton kernels, and I have two questions:
Thank you. |
Q1: We also differentiate between the
Q2:
Hi, I wonder what hyperparameters should be tuned? And as for the default config, it seems the configuration is already good enough? Do you have any suggestions for tuning the kernels? Thank you!
```python
# TODO: tune this config
N, K = lora_b_weights.shape[-2:]  # K = rank, N = hidden_size
BLOCK_K = triton.next_power_of_2(K)
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
]:
    CAST_TYPE = True
batches = lora_indices_tensor.size(0)
config = get_lora_op_configs("expand", batches, N)
grid = lambda META: (
    META["SPLIT_N"],
    batches,
)
_bgmv_expand_kernel[grid](
    inputs,
    lora_b_weights,
    output_tensor,
    N,
    K,
    lora_indices_tensor,
    inputs.stride(0),
    inputs.stride(1),
    lora_b_weights.stride(0),
    lora_b_weights.stride(1),
    lora_b_weights.stride(2),
    output_tensor.stride(0),
    output_tensor.stride(1),
    BLOCK_K=BLOCK_K,
    EVEN_K=EVEN_K,
    ADD_INPUTS=ADD_INPUTS,
    CAST_TYPE=CAST_TYPE,
    **config,
)
```
The tuning program is fast and efficient. Does the config fit the sgmv kernels? Or if I tune the bgmv kernels for batch <= 4096, do the results apply to the sgmv kernels with sequence length <= 4096? Thank you.
Sorry for so many questions, but do we need to tune for different num_stages?
The tuning logic is the same, but some things may need to be modified.
Perhaps we do.
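For anyone extending the tuning, a minimal grid-search sketch is below. It assumes a hypothetical launch_kernel(cfg) closure that launches the Triton op with one candidate config (SPLIT_N follows the snippet quoted above; num_warps and num_stages are standard Triton launch options), and uses triton.testing.do_bench for warmup and timing:

```python
import itertools
import triton.testing

def tune_expand(launch_kernel):
    """Grid-search over candidate configs; launch_kernel(cfg) is a
    user-supplied (hypothetical) closure that runs the kernel once."""
    space = itertools.product(
        [8, 16, 32, 64],   # SPLIT_N (the name used by the snippet above)
        [2, 4, 8],         # num_warps
        [1, 2, 4],         # num_stages (probably minor for a memory-bound bgmv)
    )
    best_cfg, best_ms = None, float("inf")
    for split_n, num_warps, num_stages in space:
        cfg = {"SPLIT_N": split_n, "num_warps": num_warps, "num_stages": num_stages}
        try:
            ms = triton.testing.do_bench(lambda: launch_kernel(cfg))
        except Exception:  # skip configs that fail to compile or launch
            continue
        if ms < best_ms:
            best_cfg, best_ms = cfg, ms
    return best_cfg, best_ms
```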
Hi, I have another question: for qkv_proj, why do we calculate the LoRA for q, k, and v separately?
Yes, you are right. In fact, we can also fuse lora_a and lora_b into one kernel.
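As a rough illustration of batching the q/k/v LoRA computation (assuming the three adapters share the same rank; names and shapes are illustrative, not the PR's API): one shrink over the stacked A matrices followed by per-slice expands gives the same result as three separate shrink/expand calls.

```python
import torch

T, hidden, r, out_dim = 4, 4096, 16, 1024   # made-up sizes; r shared by q/k/v
x = torch.randn(T, hidden)
a_q, a_k, a_v = (torch.randn(hidden, r) for _ in range(3))
b_q, b_k, b_v = (torch.randn(r, out_dim) for _ in range(3))

# Separate path: three shrinks and three expands.
sep = torch.cat([(x @ a) @ b for a, b in ((a_q, b_q), (a_k, b_k), (a_v, b_v))], dim=1)

# Batched path: one shrink over the stacked A matrices, then per-slice expands.
shrunk = x @ torch.cat([a_q, a_k, a_v], dim=1)          # (T, 3r) in one pass
bat = torch.cat([shrunk[:, 0 * r:1 * r] @ b_q,
                 shrunk[:, 1 * r:2 * r] @ b_k,
                 shrunk[:, 2 * r:3 * r] @ b_v], dim=1)

assert torch.allclose(sep, bat, atol=1e-2)
```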
Hi, why does the shrink kernel use tl.float32 as its output dtype? I notice the S-LoRA bgmv kernel uses the same dtype for input and output.
Because the shrink kernel uses atomic add, and Triton lacks support for BF16 atomic add; see triton-lang/triton#1387
Got it, so it's OK to use a 16-bit result when using float16. Thank you.
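A small CPU-side illustration of the accumulation pattern involved: split-K partial products are combined into an fp32 buffer (in the real kernel via tl.atomic_add, which lacked bf16 support as noted above) and only cast back to 16-bit afterwards. Shapes here are made up:

```python
import torch

x = torch.randn(8, 4096, dtype=torch.bfloat16)   # activations
a = torch.randn(4096, 16, dtype=torch.bfloat16)  # lora_a (hidden -> rank)

acc = torch.zeros(8, 16, dtype=torch.float32)    # fp32 shrink output buffer
for x_part, a_part in zip(x.float().chunk(4, dim=1), a.float().chunk(4, dim=0)):
    acc += x_part @ a_part                       # each split-K partial is added
shrunk = acc.to(torch.bfloat16)                  # downcast before the expand stage
```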
Hi, I find we can delete the expand_slice kernels, since we can slice the tensor before calling the expand kernel. I made the following changes, and the result shows no difference; diff.max() is always 0. Is that right?
I hope to get your confirmation.
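The actual diff is not reproduced above; the following is only a hedged sketch of what such a change could look like, using a plain-Python stand-in for the expand kernel. Slicing the output tensor first yields a strided view that shares storage, so writing through it matches what an expand_slice-style op with an explicit offset would write, provided the kernel honours the view's strides.

```python
import torch

def expand_ref(shrunk, lora_b, out):
    out += shrunk @ lora_b                        # stand-in for the Triton expand kernel

output_tensor = torch.zeros(4, 3 * 4096)          # e.g. a packed qkv output
shrunk = torch.randn(4, 16)
lora_b = torch.randn(16, 4096)

slice_offset, slice_size = 4096, 4096
view = output_tensor[:, slice_offset:slice_offset + slice_size]
expand_ref(shrunk, lora_b, view)                  # writes through the view into output_tensor
```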
Yes, it seems reasonable based on an initial review. However, it doesn't help much with the current performance. Could we go further and only call the shrink and expand ops once for the LoRA layers that use nslice?
Yes, I'm working on the performance of the LoRA kernels, and I'll try that.
Motivation
LoRA is highly favored within the vLLM community; there are numerous LoRA-related issues and pull requests. Thanks to @Yard1's great work, we have supported Multi-LoRA. However, there are several issues with the punica kernel (bgmv):
Performance drop, refer to: [Performance] 40% performance drop using lora vs no lora #2829
We need to modify the punica config to support new dimensions (such as fix-bgmv-kernel-640 #4007, [Kernel] Add punica dimension for Baichuan-13B #4053, [Kernel] Add punica dimension for Chinese-Mixtral #4063)
The punica kernel cannot support certain dimensions, such as 3424 (for example, in ChatGLM3 the ffn_hidden_size is 13696, and when using tp=4 it becomes 3424), because 3424 cannot be evenly divided by 64, which causes a static_assert failure. Refer to: [Bug]: RuntimeError: No suitable kernel. h_in=16 h_out=3424 dtype=Float out_dtype=BFloat16 #3793, [Bug]: Running the punica lora on Qwen1.5 32B model encountered RuntimeError: No suitable kernel. h_in=64 h_out=3424 dtype=Float out_dtype=BFloat16 #4708
Only supports GPUs with compute capability >= 8.0
Other issues, such as increased compile time and the release wheel size (I'm not entirely sure if I understand this correctly). Also, the potential need to support long context may result in the second dimension of the kernel grid exceeding 65535, leading to kernel launch failure.
How to resolve
As described in the title, I plan to resolve these issues by refactoring the punica kernel based on Triton. The main reason is that Triton provides a hardware-agnostic way of programming and targeting GPUs, currently supporting both NVIDIA and AMD. I will implement the punica Triton kernels, and:
Work plan
If my method is approved, I will spend one week completing this PR, and tuning of the Triton kernels will be completed in a subsequent PR. Currently, I want to focus on replacing the punica bgmv kernel with the sgmv Triton kernel.
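For reference, a minimal PyTorch sketch of what the bgmv/sgmv LoRA path computes: shrink each token to the LoRA rank, then expand back to the output width, according to each token's LoRA index. Names and shapes are illustrative, not this PR's API.

```python
import torch

def lora_shrink_expand_ref(x, lora_a, lora_b, lora_idx, base_out, scale=1.0):
    """x: (T, hidden_in); lora_a: (num_loras, hidden_in, r);
    lora_b: (num_loras, r, hidden_out); lora_idx: (T,), where -1 means no LoRA."""
    out = base_out.clone()
    for t in range(x.size(0)):
        i = int(lora_idx[t])
        if i < 0:
            continue
        h = x[t] @ lora_a[i]                 # "shrink": hidden_in -> rank
        out[t] += scale * (h @ lora_b[i])    # "expand": rank -> hidden_out
    return out

T, hidden_in, hidden_out, r, num_loras = 6, 128, 256, 8, 3
out = lora_shrink_expand_ref(
    torch.randn(T, hidden_in),
    torch.randn(num_loras, hidden_in, r),
    torch.randn(num_loras, r, hidden_out),
    torch.tensor([0, 0, 2, -1, 1, 2]),
    torch.zeros(T, hidden_out),
)
```

The difference between the two ops is how the index is applied: bgmv looks it up per token (decode), while sgmv groups contiguous tokens that share a LoRA into segments (prefill) so the kernel can use block-level tl.dot.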
Please leave any feedback; it will be appreciated.
Summary [06/27]
cc @zhuohan123 @Yard1 @WoosukKwon @simon-mo @mgoin @FurtherAI