
[Kernel][RFC] Refactor the punica kernel based on Triton #5036

Merged (103 commits), Aug 1, 2024

Conversation

jeejeelee
Collaborator

jeejeelee commented May 24, 2024


Motivation

LoRA is highly favored within the vLLM community; there are numerous LoRA-related issues and pull requests. Thanks to @Yard1's great work, we already support Multi-LoRA. However, there are several issues with the punica kernel (bgmv):

How to resolve

As described in the title, I plan to resolve these issues by refactoring the punica kernel based on Triton. The main reason is that Triton provides a hardware-agnostic way of programming and targeting GPUs, currently supporting both NVIDIA and AMD. I will implement the punica Triton kernels, and:

  • For the shrink kernel, I will implement it based on GroupedGEMM and SPLIT-K
  • For the expand kernel, I will implement it based on GroupedGEMM (a minimal reference sketch of this computation follows this list)
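
To make the intended computation concrete, here is a minimal PyTorch reference of the grouped shrink/expand semantics (illustrative names and shapes only, not the Triton kernels themselves):

    import torch

    # Illustrative reference of the grouped shrink/expand computation (assumed
    # names and shapes, not the actual Triton kernels): every sequence in the
    # batch picks its own LoRA pair; shrink maps hidden -> rank, expand maps
    # rank -> hidden and accumulates into the output.
    num_loras, hidden, rank = 4, 4096, 16
    seq_lens = [3, 5, 2]                                 # tokens per sequence (prefill)
    lora_ids = [0, 2, 1]                                 # LoRA chosen by each sequence
    lora_a = torch.randn(num_loras, hidden, rank)        # shrink weights
    lora_b = torch.randn(num_loras, rank, hidden)        # expand weights
    x = torch.randn(sum(seq_lens), hidden)
    y = torch.zeros_like(x)

    start = 0
    for seq_len, idx in zip(seq_lens, lora_ids):
        tokens = x[start:start + seq_len]
        buffer = tokens @ lora_a[idx]                     # shrink (split-K targets this GEMM)
        y[start:start + seq_len] += buffer @ lora_b[idx]  # expand
        start += seq_len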

Work plan

  • Implement the SGMV kernels (shrink + expand)
  • Replace the punica bgmv kernel with the SGMV Triton kernels
  • Add unit tests for the SGMV Triton kernels

If my method is approved, I will spend one week completing this PR; tuning of the Triton kernels will be completed in a subsequent PR. Currently, I want to focus on replacing the punica bgmv kernel with the SGMV Triton kernel.

Please leave any feedback; it will be greatly appreciated.

Summary [06/27]

  • I have completed the development of the Triton kernels, implementing both SGMV (shrink/expand) and BGMV (shrink/expand)
  • I have completed the integration and unit testing of the Triton kernels according to my approach. TP (2/4/8), CUDA Graph, quantization, fully sharded LoRA, NVIDIA T4, and so forth all yield correct results. IMHO, the issues above were addressed or alleviated (e.g., the performance drop); the details are in the table below.
  • This is just a start; there's more work to be done, such as further improving kernel performance and auto-tuning configurations. I believe these can be considered for future PRs.
  • Regardless of the final decision made by the vLLM team, I really enjoyed the process of addressing these issues

cc @zhuohan123 @Yard1 @WoosukKwon @simon-mo @mgoin @FurtherAI

[benchmark summary tables attached as images]

jeejeelee marked this pull request as draft May 24, 2024 17:33
@jeejeelee
Collaborator Author

cc @Yard1 @WoosukKwon

@Yard1
Collaborator

Yard1 commented May 24, 2024

@jeejeelee can you coordinate with @FurtherAI who is working on something similar? #5025

@jeejeelee
Collaborator Author

@Yard1 How can I coordinate with @FurtherAI?

I'm happy to work together to push this feature forward, but it seems like we have different approaches to kernel implementation. Which one should we choose?

@Yard1
Collaborator

Yard1 commented May 24, 2024

@jeejeelee Would it be possible for you two to talk together and pick the best approach? I can connect you two, or you can use the comments here

@FurtherAI
Contributor

We battle it out. We can take the best of both or just select one if they are close. I do have a good portion of the integration with vLLM written, but untested, so it would be good if that still applied in the end.

@jeejeelee First, can you tell me about any limitations or assumptions? What is the LoRA weight format (paged or padded)? Do the kernels operate on each LoRA at its own rank or at a max rank? Do you need a minimum group size per LoRA (tl.dot requires shapes of 16 or greater, so do you pad if there aren't enough tokens for a LoRA)? Do you have any benchmarking? Do you have any tests for correctness, or anything else you think is important?

@jeejeelee
Collaborator Author

jeejeelee commented May 25, 2024

@FurtherAI Cheers to wonderful similarities; I'm glad to see like-minded people addressing these issues. Here is my response:

  • can you tell me about any limitations or assumptions?

    My version is designed to be compatible with the current vLLM LoRA format. If I understand correctly, vLLM's LoRA format is padded. If the paged format means that the memory for different LoRAs of the same layer is not contiguous, then my version can also be made compatible with the paged format by modifying some code.

  • do the kernels operate on each lora at their rank or at a max rank, do you need a minimum group size per lora (tl.dot requires shapes of 16 or greater so do you pad if there aren't enough tokens for a lora?)

    My version operates at the maximum rank; given vLLM's LoRA format, it should also be compatible with operating at each LoRA's own rank. I also need to ensure that the smallest BLOCK_N is 16, but guaranteeing this is enough, even when the rank is 1 (see the small padding example after this list).

  • do you have any benchmarking

    I have not conducted a systematic benchmark; I have only compared against the current vLLM bgmv using NSYS, and the performance improvement is very obvious. I will complete the relevant benchmarks as soon as possible. What's more, I have also implemented the bgmv Triton kernel for the decoding stage. However, whether we need to replace sgmv with bgmv in the decoding stage should be decided after benchmarking.

  • do you have any tests for correctness

    Yes, I have tested not only against PyTorch but also against vLLM's bgmv. For example, the issue mentioned above of the kernel grid exceeding 65535 was found when testing against vLLM's bgmv.

  • anything else you think is important?

    IMHO, resolving the above issues as soon as possible is the most important thing.
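
For the tl.dot minimum-shape point above, here is a small illustrative example (hypothetical shapes, not the PR's code) showing that zero-padding the rank dimension up to 16 does not change the result:

    import torch

    # Hypothetical illustration: pad a rank-1 LoRA A so the reduction/tile
    # dimension meets tl.dot's minimum size of 16; the zero columns add nothing.
    hidden, rank, min_tile = 4096, 1, 16
    lora_a = torch.randn(hidden, rank)
    padded_a = torch.zeros(hidden, min_tile)
    padded_a[:, :rank] = lora_a

    x = torch.randn(8, hidden)
    ref = x @ lora_a                     # (8, 1)
    padded = (x @ padded_a)[:, :rank]    # same values from the padded compute
    assert torch.allclose(ref, padded)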

I really enjoyed our conversation; it deepened my understanding of Multi-LoRA. Thanks.

Cheers

@FurtherAI
Contributor

@jeejeelee thanks for the detailed response. I think you understand paged format correctly. So my kernels run on a paged lora format, like in S-LoRA. This of course allows for more active loras at a time. They also run at the rank of each lora which provides some solid speed ups, especially at higher sequence length. And they're compatible with any rank, down to 1 without padding, which means they can run fast with arbitrary grouping.

Can you give me some sample code to run your kernels so I can benchmark them? If yours has a significant speed advantage over mine, then we should make changes, but if not, mine are significantly more developed and there would be no reason to duplicate the effort.

@jeejeelee
Collaborator Author

@FurtherAI I apologize for the late reply due to the weekend. I have provided some code in temp_test.py.

@FurtherAI
Contributor

No worries, thanks for the code. I'll try to get back to you soon.

@zhuohan123
Member

This is great! The punica kernels also take a very long time to build and take up space in our wheels. Moving them to triton can make the building process faster and reduce the wheel size.

@leiwen83
Contributor

leiwen83 commented May 29, 2024

Currently this implementation still has two kernels dealing with shrink and expand separately. I wonder whether we could merge them into one, so that Triton could do pipeline autotuning for the fused kernel?

Taking it one step further, could we have a kernel that fuses x@W+x@A@B, so that the cost of LoRA inference at the decoding stage is reduced?

@jeejeelee
Collaborator Author

jeejeelee commented May 29, 2024

Currently this implementation still has two kernels dealing with shrink and expand separately. I wonder whether we could merge them into one, so that Triton could do pipeline autotuning for the fused kernel?

Taking it one step further, could we have a kernel that fuses x@W+x@A@B, so that the cost of LoRA inference at the decoding stage is reduced?

Hi, thanks for your attention and suggestions.

Of course, we can fuse these two back-to-back GEMM kernels, as done in b2b_grouped_gemm. Actually, implementing this in a single kernel was my first idea, but I have not yet figured out how to map split-K, batch_size, and M, N onto the grid. Therefore, I am currently implementing shrink and expand separately. In the future, I will continue to explore this. For now, my focus is on replacing the current bgmv and addressing the issues mentioned above.

Regarding "pipeline autotune for this fused kernel," I'm not quite sure what this means. Did you mean num_stages? If I understand correctly, if your GPU's compute capability is >= 8.0, Triton will apply pipelining optimization regardless of whether fusion is involved.

For fusing x@W+x@A@B: if we only have one LoRA, we can merge the LoRA's weights into W first and then only need to compute x@W. However, since we need to support Multi-LoRA, I think fusing everything into one kernel might not be the best choice; we also need to consider factors such as register pressure.
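
To illustrate the single-LoRA case mentioned above, here is a small PyTorch sketch (illustrative shapes, not vLLM code) showing that x@W + x@A@B can be folded into one GEMM when only one LoRA is active, which is exactly what breaks once different tokens need different LoRAs:

    import torch

    # Single-LoRA case: x@W + x@A@B == x@(W + A@B), so the LoRA can be merged
    # into the base weight. With Multi-LoRA, each token may need a different
    # A@B, so no single merged W exists.
    d_in, d_out, rank, tokens = 64, 64, 8, 4
    x = torch.randn(tokens, d_in)
    W = torch.randn(d_in, d_out)
    A = torch.randn(d_in, rank)
    B = torch.randn(rank, d_out)

    unfused = x @ W + (x @ A) @ B
    merged = x @ (W + A @ B)
    assert torch.allclose(unfused, merged, rtol=1e-4, atol=1e-4)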

@FurtherAI
Contributor

FurtherAI commented May 29, 2024

@jeejeelee So it looks like our kernels accomplish two partly different goals.

Yours can function as a drop in replacement for the current Punica kernels. I have done some benchmarking (which I'll attach) and assuming I understand how to bench yours correctly and it is doing what I expect, then I think yours have pretty comparable speed to Punica across the board.

Mine has two major differences, loras in a paged memory format and computing at the rank of each lora. Together these allow for more loras at a time and assuming loras are randomly distributed in rank, they get pretty fast and it makes up for the cost of loading the extra indexing information for having them in a page. Your strategy of looping along the rank dimension didn't give any speedup to mine, so I expect any differences are loading the indexing and the different format.

Based on the benchmarks and assuming lora ranks are distributed, these kernels don't offer a significant speed advantage over mine. Though they would probably be faster than mine if you add the ability to compute at the actual rank (lacking the paged loras though). It would be reasonable to use yours as a drop in replacement for Punica until I push mine or keep both and use mine if the user wants more active loras. Or it is reasonable to just wait for the kernels I've developed to be fully integrated since these kernels will likely be a marginal improvement.

Up to the vLLM team what they want to do.
I think the expand with repeats 8 is an anomaly here, it matches the shrink kernel on other machines.
bench_triton_kernels.txt

@jeejeelee
Collaborator Author

jeejeelee commented May 29, 2024

@jeejeelee So it looks like our kernels accomplish two partly different goals.

Yours can function as a drop in replacement for the current Punica kernels. I have done some benchmarking (which I'll attach) and assuming I understand how to bench yours correctly and it is doing what I expect, then I think yours have pretty comparable speed to Punica across the board.

Mine has two major differences, loras in a paged memory format and computing at the rank of each lora. Together these allow for more loras at a time and assuming loras are randomly distributed in rank, they get pretty fast and it makes up for the cost of loading the extra indexing information for having them in a page. Your strategy of looping along the rank dimension didn't give any speedup to mine, so I expect any differences are loading the indexing and the different format.

Based on the benchmarks and assuming lora ranks are distributed, these kernels don't offer a significant speed advantage over mine. Though they would probably be faster than mine if you add the ability to compute at the actual rank (lacking the paged loras though). It would be reasonable to use yours as a drop in replacement for Punica until I push mine or keep both and use mine if the user wants more active loras. Or it is reasonable to just wait for the kernels I've developed to be fully integrated since these kernels will likely be a marginal improvement.

Up to the vLLM team what they want to do. I think the expand with repeats 8 is an anomaly here, it matches the shrink kernel on other machines. bench_triton_kernels.txt

Thanks for your benchmark work. It seems that you mainly want to support S-LoRA, which is why you implemented your kernels that way.

Currently, I am following the above work plan and attempting to replace bgmv with my operator (progress: 80%). Additionally, I have also added the bgmv-related Triton kernels (these kernels have not been fully tested yet). Regardless, I will complete my work plan.

Cheers

@jeejeelee
Collaborator Author

jeejeelee commented May 31, 2024

Currently, while the sgmv kernels I've implemented achieve high performance in long-sequence scenarios, they fall short compared to Punica's bgmv in cases involving small batches and short sequences. I'm working on improving their performance in these scenarios.

Additionally, I've also implemented bgmv for use in the decoding stage, but it's still in the preliminary stages of development.

I've already implemented some code to replace bgmv with the Triton sgmv, mainly by adopting the previous shallow-copy approach to pass the parameters required by sgmv. However, I'm not sure if this approach is reasonable; could I trouble you to provide guidance and suggestions? @Yard1
[benchmark charts: batch_size1_128, batch_size4_2048]

@jeejeelee
Collaborator Author

@K-Mistele Thank you for your testing, I really appreciate it. By the way, the vLLM team has released version 0.5.4, and this PR is included in it. You can install it directly with pip install vllm.

@Rachneet

@jeejeelee the latest vLLM version (0.5.4) does not support lora_rank > 64

@jeejeelee
Collaborator Author

@jeejeelee the latest vLLM version (0.5.4) does not support lora_rank > 64

#7146 has addressed this issue. However, you'll have to wait until version 0.5.5 for support, or you can build from source.

@Rachneet

Got it. Thanks for your work. @jeejeelee

@askcs517

Question 1:
How should I understand the following descriptions?
1. "Only support GPU which is compute capability >= 8.0"
2. "I have completed the integration, and unit testing of the Triton kernel according to my approach. TP(2/4/8)/CUDA Graph/Quantization/Fully sharded LoRA/NVIDIA T4"
The T4's compute capability is only 7.5, not >= 8.0.

Question 2:
On a T4 with vLLM version 0.4.3, deploying multi-LoRA with vLLM failed with this error:
RuntimeError: CUDA error: no kernel image is available for execution on the device.
My deploy command:
CUDA_VISIBLE_DEVICES=0,1,2,3 swift deploy --tensor_parallel_size 4 --dtype fp16 --model_type qwen1half-7b-chat
--model_id_or_path /cloud/user/data/data0806/llm/M2/Chat_New
--ckpt_dir /cloud/user/data/data0806/llm/M2/checkpoint-200/
--infer_back vllm --vllm_enable_lora true --max_model_len 512 --enforce_eager
I tried upgrading vLLM to 0.5.5, but the error persists.
@jeejeelee

@jeejeelee
Collaborator Author

@askcs517

Question 1:

These were previous issues with Punica. This pull request addresses some of them. Please carefully read the initial description again.

Question 2:

It's unclear what error you encountered. You should also consider submitting an issue.

@askcs517

askcs517 commented Aug 28, 2024 via email

@sleepwalker2017

sleepwalker2017 commented Oct 11, 2024

Thank you for the excellent work! @jeejeelee

I'm checking the Triton kernels, and I have two questions:

  1. There are six kernels. Why are there so many? bgmv/sgmv, expand/expand_slice: what's the difference between them?
  2. Are the kernels already tuned? I didn't find a config file for each platform.

Thank you.

@askcs517

askcs517 commented Oct 11, 2024 via email

@jeejeelee
Collaborator Author

Thank you for the excellent work! @jeejeelee

I'm checking the Triton kernels, and I have two questions:

  1. There are six kernels. Why are there so many? bgmv/sgmv, expand/expand_slice: what's the difference between them?
  2. Are the kernels already tuned? I didn't find a config file for each platform.

Thank you.

Q1:
The shrink kernel is used for calculating lora_a, while the expand kernel is used for calculating lora_b.

We also differentiate between the prefill and decode stages: the prefill stage uses the sgmv kernels, while the decode stage uses the bgmv kernels. The expand_slice kernels are applicable to the qkv linear and ffn linear operations in networks similar to LLaMA.

Q2:
Not yet.
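
As a rough illustration of the decode-stage path described above (assumed names and shapes, not the actual kernel API): one token per sequence, a per-token LoRA index, and the expand result added into a slice of a fused output such as qkv_proj:

    import torch

    # bgmv-style reference: batch of single decode tokens, each with its own
    # LoRA index; expand_slice-style write into the q slice of a fused qkv output.
    num_loras, hidden, rank, batch = 4, 4096, 16, 8
    lora_a = torch.randn(num_loras, hidden, rank)
    lora_b = torch.randn(num_loras, rank, hidden)        # LoRA for the q slice only
    x = torch.randn(batch, hidden)                       # one token per sequence (decode)
    y = torch.zeros(batch, 3 * hidden)                   # fused qkv output
    lora_indices = torch.randint(0, num_loras, (batch,))

    buffer = torch.einsum("bh,bhr->br", x, lora_a[lora_indices])      # bgmv shrink
    delta = torch.einsum("br,brh->bh", buffer, lora_b[lora_indices])  # bgmv expand
    y[:, 0:hidden] += delta                              # expand_slice: q slice only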

@sleepwalker2017

sleepwalker2017 commented Oct 12, 2024

Thank you for the excellent work! @jeejeelee
I'm checking the Triton kernels, and I have two questions:

  1. There are six kernels. Why are there so many? bgmv/sgmv, expand/expand_slice: what's the difference between them?
  2. Are the kernels already tuned? I didn't find a config file for each platform.

Thank you.

Q1: The shrink kernel is used for calculating lora_a, while the expand kernel is used for calculating lora_b.

We also differentiate between the prefill and decode stages: the prefill stage uses the sgmv kernels, while the decode stage uses the bgmv kernels. The expand_slice kernels are applicable to the qkv linear and ffn linear operations in networks similar to LLaMA.

Q2: Not yet.

Hi, I wonder which hyper-parameters should be tuned?
I think BLOCK_K is already optimal, because the rank is small, so we can read it into shared memory in a single block at once.

And as for the default config, it seems the configuration is already good enough?

Do you have any suggestions for tuning the kernels? Thank you!

        return {
            "BLOCK_N": 256,
            "SPLIT_N": _check_divisibility(hidden_size),
            "num_warps": 8
        }
    # TODO tuning this config
    N, K = lora_b_weights.shape[-2:]  # K= rank,N=hidden_size
    BLOCK_K = triton.next_power_of_2(K)
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    CAST_TYPE = False
    if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
            torch.float16,
            torch.bfloat16,
    ]:
        CAST_TYPE = True
    batches = lora_indices_tensor.size(0)
    config = get_lora_op_configs("expand", batches, N)
    grid = lambda META: (
        META["SPLIT_N"],
        batches,
    )
    _bgmv_expand_kernel[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_K=BLOCK_K,
        EVEN_K=EVEN_K,
        ADD_INPUTS=ADD_INPUTS,
        CAST_TYPE=CAST_TYPE,
        **config,
    )

@jeejeelee
Collaborator Author

@sleepwalker2017 See: https://github.com/jeejeelee/punica_triton_kernel/blob/main/benchmark_bgmv.py

@sleepwalker2017

We also differentiate between the prefill and decode stages: the prefill stage uses the sgmv kernels, while the decode stage uses the bgmv kernels. The expand_slice kernels are applicable to the qkv linear and ffn linear operations in networks similar to LLaMA.

The tuning program is fast and efficient. Does the config also fit the sgmv kernels?

Or if I tune the bgmv kernels for batch <= 4096, do they apply to the sgmv kernels with sequence length <= 4096? Thank you.

@sleepwalker2017

@sleepwalker2017 See: https://github.com/jeejeelee/punica_triton_kernel/blob/main/benchmark_bgmv.py

Sorry for so many questions, but do we need to tune for different num_stages?

@jeejeelee
Collaborator Author

The tuning program is fast and efficient. Does the config also fit the sgmv kernels?

The logic of tuning is the same, but some things may need to be modified.

but do we need to tune for different num_stages?

Perhaps we do.
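
For reference, here is a generic sketch of how num_warps and num_stages can be swept with Triton's autotuner (a plain scaling kernel for illustration only, not the LoRA kernels):

    import torch
    import triton
    import triton.language as tl

    # Minimal autotune example: the search space includes block size,
    # num_warps, and num_stages; the best config is cached per key.
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
            triton.Config({"BLOCK": 1024}, num_warps=8, num_stages=3),
            triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=4),
        ],
        key=["n_elements"],
    )
    @triton.jit
    def _scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        x = tl.load(x_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x * scale, mask=mask)

    x = torch.randn(1 << 20, device="cuda")
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
    _scale_kernel[grid](x, out, 0.5, x.numel())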

@sleepwalker2017

Hi, I have another question: for qkv_proj, why do we calculate the LoRAs for q, k, and v separately?
In my view, we can fuse the LoRAs for qkv and do the bgmv in a single kernel. Is that right?

@jeejeelee
Collaborator Author

Hi, I have another question: for qkv_proj, why do we calculate the LoRAs for q, k, and v separately? In my view, we can fuse the LoRAs for qkv and do the bgmv in a single kernel. Is that right?

Yes, you are right. In fact, we can also fuse lora_a and lora_b into one kernel.

@sleepwalker2017

Hi, why does the shrink kernel use tl.float32 as the output dtype? I notice the S-LoRA bgmv kernel uses the same dtype for input and output.

@jeejeelee
Collaborator Author

Hi, why does the shrink kernel use tl.float32 as the output dtype? I notice the S-LoRA bgmv kernel uses the same dtype for input and output.

Because the shrink kernel uses atomic add, and Triton lacks support for BF16 atomic add; see: triton-lang/triton#1387
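
Schematically (illustrative shapes, not the actual kernel code), the intermediate shrink buffer stays in float32 so the split-K partial sums can be atomically accumulated, and the expand step casts back to the model dtype:

    import torch

    # Illustration only: fp32 accumulator for the shrink output, cast back on expand.
    x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
    lora_a = torch.randn(4096, 16, dtype=torch.bfloat16, device="cuda")
    lora_b = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")

    buffer = torch.zeros(8, 16, dtype=torch.float32, device="cuda")   # fp32 accumulator
    buffer += (x @ lora_a).to(torch.float32)       # shrink: split-K partials are added here
    y = buffer.to(torch.bfloat16) @ lora_b         # expand: back to the model dtype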

@sleepwalker2017

triton-lang/triton#1387

Got it. So it's OK to use a 16-bit result when using float16. Thank you.

@sleepwalker2017

sleepwalker2017 commented Nov 7, 2024

Hi, I find that we can delete the expand_slice kernels, since we can slice the output tensor before calling the expand kernel.

I made the following changes, and the result shows no difference: diff.max() is always 0. Is that right?

        self.add_shrink(buffer, x, wa_t_all, scale)
        if y_offset is None and y_slice_size is None:
            self.add_expand(y, buffer, wb_t_all, add_input=True)
        else:
            tmp = torch.clone(y)
            self.add_expand_slice(y,
                                  buffer,
                                  wb_t_all,
                                  y_offset,
                                  y_slice_size,
                                  add_input=True)

            self.add_expand(tmp[:, y_offset:y_offset+y_slice_size], buffer, wb_t_all, add_input=True)
            diff = torch.abs(tmp - y)
            print(diff.max().item(), "diff")
            if diff.max().item() > 0:
                import pdb; pdb.set_trace()

Hope to get your confirmation.

@jeejeelee
Collaborator Author

Hi, I find that we can delete the expand_slice kernels, since we can slice the output tensor before calling the expand kernel.

I made the following changes, and the result shows no difference: diff.max() is always 0. Is that right?

        self.add_shrink(buffer, x, wa_t_all, scale)
        if y_offset is None and y_slice_size is None:
            self.add_expand(y, buffer, wb_t_all, add_input=True)
        else:
            tmp = torch.clone(y)
            self.add_expand_slice(y,
                                  buffer,
                                  wb_t_all,
                                  y_offset,
                                  y_slice_size,
                                  add_input=True)

            self.add_expand(tmp[:, y_offset:y_offset+y_slice_size], buffer, wb_t_all, add_input=True)
            diff = torch.abs(tmp - y)
            print(diff.max().item(), "diff")
            if diff.max().item() > 0:
                import pdb; pdb.set_trace()

Hope to get your confirmation.

Yes, it seems reasonable on initial review. However, it doesn't help much with the current performance. Could we go further and only call the shrink and expand ops once for the LoRA layers that use nslice?

@sleepwalker2017

Hi, I find that we can delete the expand_slice kernels, since we can slice the output tensor before calling the expand kernel.
I made the following changes, and the result shows no difference: diff.max() is always 0. Is that right?

        self.add_shrink(buffer, x, wa_t_all, scale)
        if y_offset is None and y_slice_size is None:
            self.add_expand(y, buffer, wb_t_all, add_input=True)
        else:
            tmp = torch.clone(y)
            self.add_expand_slice(y,
                                  buffer,
                                  wb_t_all,
                                  y_offset,
                                  y_slice_size,
                                  add_input=True)

            self.add_expand(tmp[:, y_offset:y_offset+y_slice_size], buffer, wb_t_all, add_input=True)
            diff = torch.abs(tmp - y)
            print(diff.max().item(), "diff")
            if diff.max().item() > 0:
                import pdb; pdb.set_trace()

Hope to get your confirmation.

Yes, it seems reasonable on initial review. However, it doesn't help much with the current performance. Could we go further and only call the shrink and expand ops once for the LoRA layers that use nslice?

Yes, I'm working on the performance of the lora kernels, and I'll try that.
