[Core] Enable Memory Tiering for vLLM #8694

Closed
wants to merge 12 commits

Conversation

PanJason

This is the follow-up PR to #7697. It introduces 4 major functionalities:

  1. Context caching with TTL support, which is also a major driver of memory demand.
  2. Blockwise swapping to DRAM and disk (partial).
  3. Layered transfer between DRAM and HBM.
  4. Initial disk support.

Some detailed explanations:

  1. Context caching:
    a. Usage:
    We add CachingParams to specify the TTL of the cached context. Within the TTL, we guarantee that the cached context remains in the instance (it will not be discarded).
    Example usage for offline inference:

```python
caching_params = CachingParams(ttl=args.ttl)
cache_output = llm.caching(LONG_PROMPT, caching_params=caching_params)
```
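For orientation, here is a fuller end-to-end sketch of the offline flow. It is a sketch only, assuming the `CachingParams` class and `LLM.caching()` API added by this PR; the import path, model name, `enable_memory_tiering` keyword, and TTL value are placeholders, not the PR's confirmed interface:

```python
from vllm import LLM, SamplingParams
from vllm.caching_params import CachingParams  # import path is an assumption

LONG_PROMPT = "<a long shared system prompt or document>" * 100

# enable_memory_tiering mirrors the --enable-memory-tiering CLI flag (assumed kwarg).
llm = LLM(model="facebook/opt-125m", enable_memory_tiering=True)

# Cache the shared prefix for 10 minutes.
caching_params = CachingParams(ttl=600)
llm.caching(LONG_PROMPT, caching_params=caching_params)

# Later requests that share the prefix reuse the cached KV blocks within the TTL.
outputs = llm.generate(LONG_PROMPT + "\nQuestion: summarize the document.",
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```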

Example usage for online serving (a new endpoint, /v1/caching, is added):

```python
url += "/v1/caching"
payload["ttl"] = ttl
```

and send the payload to that URL.
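A minimal client-side sketch of that flow; the host, port, model name, and all payload fields other than `ttl` are assumptions made for illustration:

```python
import requests

base_url = "http://localhost:8000"
payload = {
    "model": "facebook/opt-125m",          # placeholder model name
    "prompt": "<the long shared prompt>",  # the context to cache
    "ttl": 600,                            # keep the cached context for 10 minutes
}

url = base_url + "/v1/caching"
resp = requests.post(url, json=payload)
resp.raise_for_status()
print(resp.json())
```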

b. Implementation:
scheduler.schedule() checks for and frees expired cached contexts (rough sketch below).
LLMEngine._process_model_outputs() moves context-caching requests to SequenceStatus.FIXED, so they are not freed within the TTL.
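As a rough illustration of that TTL bookkeeping, here is a sketch only; the attribute and method names below are hypothetical, not the PR's actual identifiers:

```python
import time

def free_expired_cached_contexts(scheduler) -> None:
    """Run at the start of scheduler.schedule(): reclaim caches whose TTL expired."""
    now = time.monotonic()
    expired = [
        sg for sg in scheduler.fixed_seq_groups          # FIXED (cached) requests
        if now - sg.cached_at >= sg.caching_params.ttl   # TTL has elapsed
    ]
    for sg in expired:
        scheduler.fixed_seq_groups.remove(sg)
        scheduler.free_seq_group(sg)  # return its KV blocks to the allocator
```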

  2. Blockwise swapping to DRAM and disk (partial)
    Usage:
    Add the --enable-memory-tiering flag.

Implementation:
Prefix matching is changed in two places: (1) can_allocate() now takes prefix matching into account when calculating how many new blocks are actually needed (see the sketch below), and (2) during prefill, we check whether a matched block resides in DRAM or on disk and, if so, fetch it back in.
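A simplified sketch of the first change; the helper and counter names are illustrative, not the PR's real identifiers:

```python
def can_allocate(block_manager, seq_group) -> bool:
    """Decide whether a waiting sequence group fits, counting prefix-cache hits."""
    seq = seq_group.get_seqs()[0]
    num_required = len(seq.logical_token_blocks)

    # Blocks whose content already exists in the GPU, DRAM, or disk tier do not
    # need fresh GPU allocations: GPU hits are reused and DRAM/disk hits are
    # swapped back in during prefill.
    num_cached = block_manager.count_prefix_matched_blocks(seq)  # hypothetical helper

    num_new_blocks = num_required - num_cached
    return num_new_blocks <= block_manager.num_free_gpu_blocks()  # hypothetical counter
```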

Block allocation is changed for prefill and decode. When we allocate a new block and the block being evicted has already been computed, we check whether there is free space in DRAM or on disk and swap the block out if there is enough space. Note that this can have a cascading effect (e.g. GPU -> CPU, then CPU -> Disk), which we also handle; a sketch follows.
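A rough sketch of that cascading eviction; the tier names and helper methods are illustrative, and the real policy lives in the block manager:

```python
def evict_with_cascade(allocator, block) -> bool:
    """Push a computed block down the tiers: GPU -> CPU (DRAM) -> disk."""
    for src, dst in (("gpu", "cpu"), ("cpu", "disk")):
        if block.device != src:
            continue
        if allocator.has_free_block(dst):                 # room in the next tier
            allocator.swap_out(block, dst)                # copy data, update mappings
            return True
        # Next tier is full: try to push one of its blocks further down first
        # (the cascading case), then retry the original swap.
        victim = allocator.pick_victim(dst)               # hypothetical eviction policy
        if victim is not None and evict_with_cascade(allocator, victim):
            allocator.swap_out(block, dst)
            return True
    return False  # nowhere to go; the block is dropped and recomputed if needed
```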

An rmap (reverse map) is introduced to track which sequences are mapped to each block, since context-caching blocks have to stay resident during their TTL. Blocks used only for context caching are swappable (otherwise unused caches would occupy all of HBM), so during eviction we need to find the sequences mapped to a given block; a minimal sketch follows.
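A minimal sketch of such a reverse map, using a plain dictionary; the actual structure and method names in the PR may differ:

```python
from collections import defaultdict

class ReverseMap:
    """Maps a physical block number to the ids of the sequences that use it."""

    def __init__(self) -> None:
        self._seqs_by_block = defaultdict(set)

    def add(self, block_number: int, seq_id: int) -> None:
        self._seqs_by_block[block_number].add(seq_id)

    def remove(self, block_number: int, seq_id: int) -> None:
        self._seqs_by_block[block_number].discard(seq_id)
        if not self._seqs_by_block[block_number]:
            del self._seqs_by_block[block_number]

    def seqs_of(self, block_number: int) -> set:
        # Consulted at eviction time: a block referenced only by FIXED (cached)
        # sequences may be swapped out; one referenced by a running sequence may not.
        return self._seqs_by_block.get(block_number, set())
```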

  3. Layered transfer between HBM and DRAM:
    Usage:
    Add the --enable-layered-transfer flag.

Implementation:
We create a dedicated CUDA stream for transfers. In the attention computation, the current layer first waits on this stream, then enqueues the data transfer for the next layer onto it, and only then starts computing the current layer, so transfer and compute overlap (sketch below).
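A simplified PyTorch sketch of that overlap; the per-layer KV tensors, the CPU staging buffers, and the prefetch bookkeeping are placeholders rather than the PR's actual code:

```python
import torch

transfer_stream = torch.cuda.Stream()

def forward_all_layers(layers, hidden, cpu_kv_blocks, gpu_kv_cache):
    # cpu_kv_blocks are assumed to be in pinned memory so the copies are truly async.
    # Prefetch layer 0's swapped-out blocks before the loop starts.
    with torch.cuda.stream(transfer_stream):
        gpu_kv_cache[0].copy_(cpu_kv_blocks[0], non_blocking=True)

    for i, layer in enumerate(layers):
        # Wait until layer i's blocks have finished arriving on the GPU.
        torch.cuda.current_stream().wait_stream(transfer_stream)

        # Kick off the transfer for layer i + 1 on the dedicated stream...
        if i + 1 < len(layers):
            with torch.cuda.stream(transfer_stream):
                gpu_kv_cache[i + 1].copy_(cpu_kv_blocks[i + 1], non_blocking=True)

        # ...while layer i computes on the default stream.
        hidden = layer(hidden, gpu_kv_cache[i])
    return hidden
```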

  4. Disk support (tentative)
    Usage:
    Add the --enable-disk-swap flag and pass the config file with --disk-swap-config swap.cfg.

Implementation:
For now, disk swapping uses the same block abstraction and the same granularity as the other tiers. The swap manager is split into a stateful part (SwapDeviceManager) and a stateless part (SwapDeviceClient), because the scheduler and the workers run in different processes and it is hard to share state between them (we can do it, it just takes more time). The stateful part manages block allocation, while the stateless part is responsible for the actual transfers (similar to the separation between BlockManager and CacheEngine).
Right now each block of each layer is written to its own file; I will optimize this later (rough sketch of this side below).
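A rough sketch of the stateless transfer side under the one-file-per-block-per-layer layout described above; the directory layout, file naming, and method names are assumptions, and only the SwapDeviceManager/SwapDeviceClient split itself comes from the PR:

```python
import os
import torch

class SwapDeviceClient:
    """Stateless side: moves KV-cache blocks between tensors and disk files."""

    def __init__(self, swap_dir: str) -> None:
        self.swap_dir = swap_dir
        os.makedirs(swap_dir, exist_ok=True)

    def _path(self, layer: int, block_number: int) -> str:
        # One file per (layer, block) pair, matching the current unoptimized layout.
        return os.path.join(self.swap_dir, f"layer{layer}_block{block_number}.pt")

    def swap_out(self, layer: int, block_number: int, block: torch.Tensor) -> None:
        torch.save(block.cpu(), self._path(layer, block_number))

    def swap_in(self, layer: int, block_number: int, dst: torch.Tensor) -> None:
        dst.copy_(torch.load(self._path(layer, block_number), map_location="cpu"))

    def remove(self, layer: int, block_number: int) -> None:
        os.remove(self._path(layer, block_number))
```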

Support status:

  1. Only BlockManagerV1 is supported. BlockManagerV2 is WIP.
  2. Disk does not support layered transfer yet, as I am still exploring the interface.
  3. Disk swap only supports the TP=n case, as PP=n requires sharing state between multiple schedulers.

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well-documented to ensure future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies user-facing behavior of vLLM. This helps vLLM users understand and use the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.library.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status updates every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

PanJason and others added 12 commits September 20, 2024 10:16
Right now only the prospective interface for layered transfer is added. Because vLLM only supports one CUDA stream, transfer and computation do not really overlap yet. I am also uncertain whether this way of passing the parameters is good.
WIP: Multiple CUDA streams
We have to separate the stateful allocation and the stateless transmission in the vLLM implementation (the workers run in multiple processes).
- Add initial support for context caching:
    1. Support the endpoint
    2. Introduce another sequence state, is_fixed. is_fixed is now also treated as is_finished.
    3. Note that the cached context currently always resides in HBM, because its blocks are marked as allocated and, by default, allocated blocks are not swapped to any secondary storage.
* Add example disk swap config. Add unit tests for CC with memory tiering

* Layered transfer for DRAM. Transfer in cuda streams

* Fix the missing arg

* Fix context caching online serving

This commit enables layered transfer for DRAM first. The transfers are now done in separate CUDA streams. The xFormers, FlashInfer, and FlashAttention backends are supported. Optimized transfer for disk is still pending.

Cherry-pick Yangshen's commit

---------

Co-authored-by: yangshen <yangshen.d@outlook.com>
To be tested. Rebased on v0.6.1.post2

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@Amelia26345

Great idea, but it seems there are performance issues at this stage. In my testing, the execution time of scheduler.schedule() has increased more than tenfold compared to vLLM 0.6.1.post2, even with only prefix caching enabled. Is there an optimization plan for this issue?

@PanJason
Author

> Great idea, but it seems there are performance issues at this stage. In my testing, the execution time of scheduler.schedule() has increased more than tenfold compared to vLLM 0.6.1.post2, even with only prefix caching enabled. Is there an optimization plan for this issue?

Hi, very glad to see you are interested. Did you run my branch for your test? And did you see an end-to-end slowdown because of scheduling? I have not looked into the scheduling logic yet, but I guess it is because I put a lot of policy decisions in the block_manager, which is invoked by the scheduler in multiple places.

@Amelia26345

> Great idea, but it seems there are performance issues at this stage. In my testing, the execution time of scheduler.schedule() has increased more than tenfold compared to vLLM 0.6.1.post2, even with only prefix caching enabled. Is there an optimization plan for this issue?
>
> Hi, very glad to see you are interested. Did you run my branch for your test? And did you see an end-to-end slowdown because of scheduling? I have not looked into the scheduling logic yet, but I guess it is because I put a lot of policy decisions in the block_manager, which is invoked by the scheduler in multiple places.

Yes, I ran the v0.6.1.post2_tiering branch. I measured the scheduling time separately. When enable_caching = True, the scheduling time increases. When enable_memory_tiering = True, both the scheduling time and the execution time of execute_worker increase, indicating that the performance loss caused by scheduling is larger. Do you have any ideas for optimizing scheduling? Perhaps you can refer to some methods in the CacheAttention paper to optimize the pipeline.
I would also like to ask: does context caching pin the input content in GPU memory for a period of time? Can you briefly explain the usage scenarios of context caching? Thanks!

@PanJason
Author

PanJason commented Oct 2, 2024

> Great idea, but it seems there are performance issues at this stage. In my testing, the execution time of scheduler.schedule() has increased more than tenfold compared to vLLM 0.6.1.post2, even with only prefix caching enabled. Is there an optimization plan for this issue?
>
> Hi, very glad to see you are interested. Did you run my branch for your test? And did you see an end-to-end slowdown because of scheduling? I have not looked into the scheduling logic yet, but I guess it is because I put a lot of policy decisions in the block_manager, which is invoked by the scheduler in multiple places.
>
> Yes, I ran the v0.6.1.post2_tiering branch. I measured the scheduling time separately. When enable_caching = True, the scheduling time increases. When enable_memory_tiering = True, both the scheduling time and the execution time of execute_worker increase, indicating that the performance loss caused by scheduling is larger. Do you have any ideas for optimizing scheduling?

Yeah, I know one reason that explains why the scheduling latency increases in both cases. I enabled a small optimization when checking whether a seq can be allocated or not. The original vLLM simply compares the number of blocks the seq requires with the available free blocks to decide whether the seq can be allocated. My code performs prefix matching to get the exact number of required blocks, so the decision is less conservative. I do not cache the block hash values yet, so we do one additional hash computation whose cost is O(L^2). I am planning to cache it later.

I guess the execution time increases because of some additional logic that checks whether layered transfer is enabled or not; I am not sure that would lead to a huge increase. How much of an increase did you see there?

> Perhaps you can refer to some methods in the CacheAttention paper to optimize the pipeline. I would also like to ask: does context caching pin the input content in GPU memory for a period of time? Can you briefly explain the usage scenarios of context caching? Thanks!

Right now, vLLM does not have good support for prefix caching with multi-modality, because the video part is filled with a default placeholder and is only updated when the video model finishes execution. We will support prefix caching for multi-modality later this year.

Right now, context caching only supports caching a prompt for ttl time. The code guarantees that the cached context exists in the system for that duration. We will consider adding a file interface later, so that cached contexts can be generated from uploaded files.

Sorry for my late reply. I hope this helps.

@Amelia26345

> Great idea, but it seems there are performance issues at this stage. In my testing, the execution time of scheduler.schedule() has increased more than tenfold compared to vLLM 0.6.1.post2, even with only prefix caching enabled. Is there an optimization plan for this issue?
>
> Hi, very glad to see you are interested. Did you run my branch for your test? And did you see an end-to-end slowdown because of scheduling? I have not looked into the scheduling logic yet, but I guess it is because I put a lot of policy decisions in the block_manager, which is invoked by the scheduler in multiple places.
>
> Yes, I ran the v0.6.1.post2_tiering branch. I measured the scheduling time separately. When enable_caching = True, the scheduling time increases. When enable_memory_tiering = True, both the scheduling time and the execution time of execute_worker increase, indicating that the performance loss caused by scheduling is larger. Do you have any ideas for optimizing scheduling?
>
> Yeah, I know one reason that explains why the scheduling latency increases in both cases. I enabled a small optimization when checking whether a seq can be allocated or not. The original vLLM simply compares the number of blocks the seq requires with the available free blocks to decide whether the seq can be allocated. My code performs prefix matching to get the exact number of required blocks, so the decision is less conservative. I do not cache the block hash values yet, so we do one additional hash computation whose cost is O(L^2). I am planning to cache it later.
>
> I guess the execution time increases because of some additional logic that checks whether layered transfer is enabled or not; I am not sure that would lead to a huge increase. How much of an increase did you see there?
>
> Perhaps you can refer to some methods in the CacheAttention paper to optimize the pipeline. I would also like to ask: does context caching pin the input content in GPU memory for a period of time? Can you briefly explain the usage scenarios of context caching? Thanks!
>
> Right now, vLLM does not have good support for prefix caching with multi-modality, because the video part is filled with a default placeholder and is only updated when the video model finishes execution. We will support prefix caching for multi-modality later this year.
>
> Right now, context caching only supports caching a prompt for ttl time. The code guarantees that the cached context exists in the system for that duration. We will consider adding a file interface later, so that cached contexts can be generated from uploaded files.
>
> Sorry for my late reply. I hope this helps.

Thanks for the reply. The time of execute_worker (data movement) takes up about 20% of the entire inference time (GPU: H100; the proportion should be lower for other GPU types). I hope this data helps you.

@PanJason
Author

Hello, thanks for the data! I am actually working on reducing this part now. I need to get access to an H100 to test it.

@ClarkChin08

ClarkChin08 commented Oct 29, 2024

> Hello, thanks for the data! I am actually working on reducing this part now. I need to get access to an H100 to test it.

Hi, I think the idea is quite good and I have tested this PR on an A100 GPU, but it seems to have an issue like the one below:
[error screenshot]
Do you know why this happened? Thanks!

I just used `pip install -e .` to build vLLM on my local machine and ran `python benchmarks/benchmark_context_caching.py` for the test.

@PanJason
Author

> Hello, thanks for the data! I am actually working on reducing this part now. I need to get access to an H100 to test it.
>
> Hi, I think the idea is quite good and I have tested this PR on an A100 GPU, but it seems to have an issue like the one in the attached screenshot. Do you know why this happened? Thanks!
>
> I just used `pip install -e .` to build vLLM on my local machine and ran `python benchmarks/benchmark_context_caching.py` for the test.

Let me have a look at what is happening. It is likely that I changed some APIs internally.

@zeroorhero

Hi! I have an idea. Could we support a key-value database similar to Valkey (Redis over RDMA), where the key is the hash of the tokens? That way, the data in the database could be shared by multiple vLLM instances.

@qiuyuleng1

> Hello, thanks for the data! I am actually working on reducing this part now. I need to get access to an H100 to test it.
>
> Hi, I think the idea is quite good and I have tested this PR on an A100 GPU, but it seems to have an issue like the one in the attached screenshot. Do you know why this happened? Thanks!
>
> I just used `pip install -e .` to build vLLM on my local machine and ran `python benchmarks/benchmark_context_caching.py` for the test.

I hit the same error on a V100. Thanks!
