[Model][Jamba] Mamba cache single buffer #6739
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, comment /ready on the PR. 🚀
Force-pushed from dc9bf07 to d57ccb6
PR is ready, CI failures are not related to this PR.
Just starting to read through now. At a high level, the approach makes sense to me. Do you anticipate any cases where we'll end up shuffling a lot of data in and out of the first N slots?
And do you have any end-to-end performance numbers you can share?
You mentioned an added test for parallel sampling, but it's not present in this PR. Did you mean to remove it? I noticed that the added test was there previously.
Thank you for the review! Sorry for the long delay. Most shuffling occurs during the transition from prefill steps to decode steps. However, shuffling between sequential decode steps (which make up the majority of steps under a regular load) doesn't happen very often, since the cache is already in place (the previous implementation copied the Mamba cache from buffer to buffer in each and every decode step). Regarding end-to-end performance: we benchmark the prefill and decode forward passes independently. We've seen a 1-2 ms speed-up in decoding and no change in prefill steps. However, the main purpose of this PR is to reduce memory usage. The red line is the previous implementation, the blue line is this PR's implementation. RE the parallel sampling test: yes, I intended to add it, but the tiny Jamba model we use for unit tests behaves differently on different devices, so I've left it out for now until we have a trained tiny model for tests.
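For illustration only, here is a minimal sketch of the single-buffer bookkeeping discussed above, assuming a dict-based request-to-slot mapping. All names are made up and this is not the code in this PR: the states live in one persistent buffer, and before each forward pass the batch's states are swapped into the first n slots instead of being copied into a second buffer on every decode step.

```python
import torch

class SingleBufferMambaCache:
    """Illustrative sketch: one persistent buffer of `max_slots` Mamba states.

    Before each forward pass, the states of the requests in the current batch
    are swapped into slots 0..n-1 so the kernels can work on a contiguous view
    of the single buffer, instead of copying states into a second buffer.
    """

    def __init__(self, max_slots: int, state_shape: tuple[int, ...]):
        self.buffer = torch.zeros((max_slots, *state_shape))
        self.slot_of_request: dict[str, int] = {}   # request id -> slot index
        self.free_slots = list(range(max_slots))

    def _swap_slots(self, a: int, b: int) -> None:
        if a == b:
            return
        # Swap the two rows of the buffer (RHS fancy indexing makes a copy).
        self.buffer[[a, b]] = self.buffer[[b, a]]
        # Keep the request-to-slot mapping consistent with the swap.
        for req, slot in self.slot_of_request.items():
            if slot == a:
                self.slot_of_request[req] = b
            elif slot == b:
                self.slot_of_request[req] = a
        # Keep the free-slot bookkeeping consistent with the swap.
        for i, s in enumerate(self.free_slots):
            if s == a:
                self.free_slots[i] = b
            elif s == b:
                self.free_slots[i] = a

    def prepare_batch(self, request_ids: list[str]) -> torch.Tensor:
        """Move the batch's states into the first len(request_ids) slots."""
        for target, req in enumerate(request_ids):
            if req not in self.slot_of_request:
                self.slot_of_request[req] = self.free_slots.pop()
            self._swap_slots(self.slot_of_request[req], target)
        # Contiguous view handed to the forward pass.
        return self.buffer[: len(request_ids)]
```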
I left a few comments in-line. Generally, I think the approach makes sense and don't see any specific problems, but I think we should get somebody working on multi-step scheduling to review in case any conflicts might arise there. @alexm-neuralmagic could you look into that and suggest other reviewers as well?
I think the functions that manage the mamba cache might be better organized if they were factored out and encapsulated in their own class. I was thinking we could try to make it behave similarly to the BlockManager in terms of interface. Goal would be to incrementally make the mamba cache fit into vLLM's native systems. Doesn't have to be in this PR but curious to hear your thoughts on this.
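As a rough illustration of that suggestion, an interface-only sketch of what a factored-out manager could look like. All names here are hypothetical suggestions loosely modeled on an allocate/free-style interface, not existing vLLM classes or methods:

```python
class MambaCacheManager:
    """Interface sketch only; method names are suggestions, not vLLM code."""

    def __init__(self, max_slots: int) -> None:
        ...

    def allocate(self, request_id: str) -> int:
        """Reserve a free cache slot for a new request; raise if none is free."""

    def free(self, request_id: str) -> None:
        """Release the request's slot when it finishes or is preempted."""

    def prepare_current_run(self, request_ids: list[str]) -> None:
        """Swap the batch's states into slots 0..len(request_ids)-1."""

    def copy(self, parent_id: str, child_id: str) -> None:
        """Duplicate a parent's state for parallel sampling (n > 1)."""
```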
One last question: A lot of this would be simpler if the two mamba cache update functions took a list of indices rather than requiring contiguous tensors. Have you looked into this at all? To me it looks like it wouldn't be too technically difficult to do, but would require a pair of PRs on https://github.com/Dao-AILab/causal-conv1d and https://github.com/state-spaces/mamba. Might be worth it just to avoid the state management.
Thank you for the review! @alexm-neuralmagic would love to hear your opinion.
@mzusman @tlrmchlsmth Did a quick pass over the PR and I see that the changes are inside the forward() function of the model itself. The multi-step logic is "above" this function, so I don't think it should interfere with the changes here. Btw, nice optimization!
@mzusman FYI I am working on modifying the kernels to take a tensor of indices for the batch coordinates. I think this branch gives us the interface we'd need to avoid all of the state copying for causal_conv1d_update: Going to try to do the same thing to
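For context, a sketch of the gather/scatter semantics such an index-based kernel interface would provide. The variable names and the placeholder update are assumptions, and the actual fused kernel signature is not shown here; the point is only that each sequence's state could stay in its original slot:

```python
import torch

# Instead of requiring the batch's states to occupy slots 0..n-1, an
# index-based update kernel would receive a tensor of slot indices and
# gather/scatter the states itself, roughly like this:
max_slots, dim, width = 64, 128, 4
conv_state = torch.zeros(max_slots, dim, width)
slots_of_batch = torch.tensor([7, 42, 3])        # each request keeps its own slot

batch_states = conv_state[slots_of_batch]         # gather the batch's states
batch_states += 1.0                               # stand-in for the conv update
conv_state[slots_of_batch] = batch_states         # scatter the updated states back
```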
That's really great! Cache management will be easier to handle. That's right, landing this PR is quite urgent for us at the moment, and it does not block future improvements. I think it would be better to split those improvements into follow-up PRs.
FYI I just restarted the failed jobs
Does it make sense to add unit tests for the utils that maintain the cache? Seems like they're complicated enough to want additional testing. Beyond that, LGTM if green
I think it makes sense to add unit tests for the cache management utils. RE CI: I'll rebase, maybe that will help; the failures don't seem to be related to this PR.
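In the spirit of that suggestion, a hypothetical unit test for the slot-management behavior. It exercises the illustrative SingleBufferMambaCache sketch from earlier in this thread, not vLLM's actual helpers:

```python
import torch

def test_prepare_batch_moves_states_to_front():
    cache = SingleBufferMambaCache(max_slots=8, state_shape=(4,))
    # Give two requests distinguishable states via the contiguous view.
    view = cache.prepare_batch(["a", "b"])
    view[0].fill_(1.0)   # request "a"
    view[1].fill_(2.0)   # request "b"
    # A new batch ordering should still surface each request's own state
    # in the first n slots.
    view = cache.prepare_batch(["b", "a"])
    assert torch.all(view[0] == 2.0)
    assert torch.all(view[1] == 1.0)
```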
* WIP - working on swapping indices
* WIP
* Save changes
* Organize indices during assignment, working and passing tests!
* Add TODOs
* Remove diff
* Format
* Remove TODOs
* Remove unused code
* Cleanup
* Cleanup
* Cleanup the redundant 10 blocks
* Small changes
* Simplify code and add comments
* Renaming and simplify
* Remove return
* Clean up
* Cleanup
* Renaming
* Another clean up
* Clean up
* Clean up and simplify more
* Add n > 1 test
* Format
* Cosmetics
* Add functionality to find first free
* Raise exception if could not find spot
* Typos
* Add 2 slots as precaution

---------

Co-authored-by: Mor Zusman <morz@ai21.com>
This reverts commit 381c2aa.
This reverts commit f1e792d.
This reverts commit bda9876.
Force-pushed from 2f5293b to 3eeeeb7
Going to merge this one, and then try to simplify with updated kernels :) Thanks!
Co-authored-by: Mor Zusman <morz@ai21.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
By carefully allocating each sequence's Mamba cache in the first "n" slots of the cache buffer before the forward pass, we can remove the redundant CG (CUDA graph) Mamba buffer.
This PR saves memory, simplifies the Jamba inner-state management code, and reduces latency (by removing redundant data copies).
This PR is also applicable to #6484, @tlrmchlsmth.
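A minimal sketch of why a single buffer is enough for CUDA-graph runs, assuming the slot management illustrated earlier in the thread. The sizes and the decode step below are stand-ins, not the Jamba model code, and a CUDA device is required:

```python
import torch

# The CUDA graph is captured once against the fixed-address view
# buffer[:max_batch]; at runtime we only need the active requests' states to
# occupy slots 0..n-1 before replaying the graph, so no second "CG" copy of
# the cache is needed.
max_slots, max_batch, state_dim = 64, 8, 16
mamba_cache = torch.zeros(max_slots, state_dim, device="cuda")
graph_view = mamba_cache[:max_batch]        # fixed pointer seen by the graph

def decode_step(states: torch.Tensor) -> None:
    # Stand-in for the Mamba decode forward pass; updates states in place.
    states.mul_(0.9).add_(1.0)

g = torch.cuda.CUDAGraph()
# (A warm-up run on a side stream is recommended before capture; omitted here.)
with torch.cuda.graph(g):
    decode_step(graph_view)

# Each decode step: move the active requests' states into the first n slots
# (see the slot-swapping sketch above), then replay the captured graph.
g.replay()
```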