Motivation.
Now that fast processors have been merged into HF Transformers, we can call HF processors with device="cuda" to run them on GPU. This has been shown to improve processing speed by up to an order of magnitude, though reaching that speedup requires a large batch size, which rarely happens in practice since we call the HF processor once per prompt.
Benchmarks: huggingface/transformers#39591
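For context, here is a minimal sketch of what GPU preprocessing looks like when calling the HF processor directly, assuming a model whose fast image processor accepts a `device` kwarg at call time (the image path is illustrative):

    from PIL import Image
    from transformers import AutoImageProcessor

    # use_fast=True selects the torch-backed "fast" image processor, if the model has one.
    image_processor = AutoImageProcessor.from_pretrained(
        "Qwen/Qwen2.5-VL-3B-Instruct", use_fast=True)

    image = Image.open("example.jpg")  # illustrative input
    # With device="cuda", the resize/rescale/normalize steps run on GPU and the
    # returned pixel values stay there.
    inputs = image_processor(images=image, return_tensors="pt", device="cuda")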
Proposed Change.
We can support this via mm_processor_kwargs, e.g.:
from vllm import LLM

# Use CUDA by default
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          mm_processor_kwargs={"use_fast": True, "device": "cuda"})

# Use CUDA for one request - technically possible, but memory profiling would have
# to assume GPU preprocessing even if no request actually uses it, so we likely
# won't support this niche case, to save code complexity
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          mm_processor_kwargs={"use_fast": True})
llm.generate({
    "prompt": formatted_prompt,
    "multi_modal_data": {"image": image},
    "mm_processor_kwargs": {"device": "cuda"},
})
# -- or --
llm.chat(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        }
    ],
    mm_processor_kwargs={"device": "cuda"},
)

Issues to solve:
- We need to accurately determine whether the HF processor accepts the `device` argument (or any other argument in general); see the introspection sketch after this list ([Misc] Automatically resolve HF processor init kwargs #22005)
- How do we accommodate this in memory profiling so that users cannot cause OOM during inference? ([Core] Enable HF processing on GPU #22070)
- Release benchmarks on vLLM
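As a rough illustration of the first point, a signature-based check could look like the sketch below. `processor_accepts_kwarg` is a hypothetical helper, not an existing vLLM function, and it deliberately gives up when the processor takes **kwargs:

    import inspect

    def processor_accepts_kwarg(processor, name: str) -> bool:
        """Best-effort check for whether `processor.__call__` accepts a given kwarg."""
        try:
            sig = inspect.signature(processor.__call__)
        except (TypeError, ValueError):
            return False
        params = sig.parameters.values()
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
            # Many HF processors take **kwargs and forward them; the signature alone
            # cannot tell us whether `name` is actually used downstream, so reject.
            return False
        return name in sig.parameters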
Follow-ups:
- After processing, some (but not all!) of the inputs will be on GPU, yet they still need to be transferred between the API and EngineCore processes. Can we avoid serializing these inputs, which wastefully moves them back to CPU?
- To increase the batch size, we can accumulate multiple prompts and then run the HF processor on all of their inputs together (see the sketch after this list). This is expected to improve throughput at the cost of latency.
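A minimal sketch of the second follow-up, assuming we are free to hold requests briefly before preprocessing (the `pending`/`flush` names are illustrative, not vLLM APIs):

    from PIL import Image

    pending: list[Image.Image] = []

    def add_request(image: Image.Image) -> None:
        # Buffer the image instead of preprocessing it immediately.
        pending.append(image)

    def flush(image_processor):
        """Run the HF processor once over all accumulated images.

        Larger batches amortize dispatch and kernel-launch overhead on GPU,
        trading a little extra latency for throughput.
        """
        if not pending:
            return None
        batch = image_processor(images=pending, return_tensors="pt", device="cuda")
        pending.clear()
        return batch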
Feedback Period.
No response
CC List.
@hmellor @Isotr0py @ywang96 @njhill
Alternatives
No response
Any Other Things.
No response