[Bugfix] Use a separate FlashInfer workspace buffer for trtllm-gen #25520
Conversation
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Code Review
This pull request correctly identifies and fixes a memory corruption issue by introducing a separate, zero-initialized workspace buffer for trtllm-gen FlashInfer kernels. This prevents state corruption between different kernel types. My review focuses on improving the implementation of this fix by addressing a potential race condition. I've suggested making the initialization of the new global workspace buffer thread-safe using a lock to prevent issues in multi-threaded environments.
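A minimal sketch of the thread-safe lazy initialization the review suggests, assuming an illustrative buffer size and variable names (not necessarily those used in vLLM):

```python
import threading

import torch

# Illustrative constants and names only; the actual size and names in vLLM may differ.
_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024

_trtllm_gen_workspace_buffer = None
_workspace_lock = threading.Lock()


def _get_trtllm_gen_workspace_buffer(device: str = "cuda") -> torch.Tensor:
    global _trtllm_gen_workspace_buffer
    if _trtllm_gen_workspace_buffer is None:
        with _workspace_lock:
            # Double-checked locking: only one thread performs the allocation,
            # and the buffer is zero-initialized as trtllm-gen requires.
            if _trtllm_gen_workspace_buffer is None:
                _trtllm_gen_workspace_buffer = torch.zeros(
                    _WORKSPACE_SIZE_BYTES, dtype=torch.uint8, device=device)
    return _trtllm_gen_workspace_buffer
```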
mgoin left a comment:
Seems reasonable. Is it fine for prefill and decode to share the workspace?
@mgoin yes, it seems fine. They both use the workspace in a similar way. It's implied by the FlashInfer tests, since all cases use the same global workspace buffer for both prefill and decode.
See also this PR comment, which notes that re-using the buffer between tests is expected behaviour.
```python
def _get_trtllm_gen_workspace_buffer():
    global trtllm_gen_workspace_buffer
    if trtllm_gen_workspace_buffer is None:
        trtllm_gen_workspace_buffer = torch.zeros(
```
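The diff excerpt above is cut off by the review view; a plausible continuation, assuming a hypothetical size constant `_WORKSPACE_SIZE_BYTES` (the actual constant and size in the PR may differ), is:

```python
        # Assumed continuation of the truncated excerpt above; the exact
        # size constant used in the PR may differ.
        trtllm_gen_workspace_buffer = torch.zeros(
            _WORKSPACE_SIZE_BYTES, dtype=torch.uint8, device="cuda")
    return trtllm_gen_workspace_buffer
```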
From the FlashInfer PR, it says that trtllm-gen requires a zero-initialized buffer, but the FlashInfer kernels do not:
"No. trtllm-gen kernel and fi kernel should re-use individual workspace as fi kernel does not require zero-init workspace."
Can we make it always a zero-init buffer? My only concern is that if we run both flavors of kernels for performance, we'd end up occupying double the workspace.
I suppose the concern here is that FlashInfer does not clean up the buffer, and we want to keep the two separate?
Purpose
I've been getting some rare illegal memory accesses while developing with the trtllm-gen FlashInfer kernels.
I believe the main issue comes down to the fact that the trtllm-gen and non-trtllm-gen kernels need separate workspaces. Here is a FlashInfer PR (merged) that updates the tests to avoid this issue.
flashinfer-ai/flashinfer#1643
Detailed summary
FlashInfer's wrapper-based kernels (both prefill and decode) use the workspace buffer as scratch space for storing intermediate results (such as split-k accumulation data). They do not require it to be zero-initialized and may not clean it up after writing data into it.
The trtllm-gen kernels, on the other hand, require their workspace buffer to be zero-initialized, and they clean it up after use to maintain this invariant.
vLLM currently uses the same workspace buffer for all four combinations (trtllm-gen or not, prefill or decode). This leads to rare illegal memory accesses when one kernel corrupts the state expected by another. This PR adds a dedicated, zero-initialized buffer for the trtllm-gen kernels. With this change in place, I stress-tested my development deployment and no longer see any crashes.
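A minimal sketch of the buffer-selection idea described above; the function name, size constant, and the flag used to detect trtllm-gen are illustrative assumptions, not the actual vLLM call sites:

```python
import torch

_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024  # assumed size, in bytes

_shared_workspace_buffer = None       # scratch space for wrapper-based kernels
_trtllm_gen_workspace_buffer = None   # zero-initialized, trtllm-gen only


def get_workspace(use_trtllm_gen: bool, device: str = "cuda") -> torch.Tensor:
    """Return the workspace matching the kernel flavor being launched."""
    global _shared_workspace_buffer, _trtllm_gen_workspace_buffer
    if use_trtllm_gen:
        if _trtllm_gen_workspace_buffer is None:
            # trtllm-gen expects a zero-initialized workspace and restores it
            # after use, so it must not be shared with the wrapper kernels.
            _trtllm_gen_workspace_buffer = torch.zeros(
                _WORKSPACE_SIZE_BYTES, dtype=torch.uint8, device=device)
        return _trtllm_gen_workspace_buffer
    if _shared_workspace_buffer is None:
        # The wrapper-based prefill/decode kernels only need scratch space and
        # may leave arbitrary data behind, so zero-init is not required.
        _shared_workspace_buffer = torch.empty(
            _WORKSPACE_SIZE_BYTES, dtype=torch.uint8, device=device)
    return _shared_workspace_buffer
```

Giving the trtllm-gen path its own lazily allocated buffer keeps the extra memory cost to a single additional workspace, and only when those kernels are actually used.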