Llama3.2 Vision Model: Guides and Issues #8826
What is the ETA for adding support for multiple images? For a project I'm working on, we're trying to see if it's viable to use vLLM for Llama 3.2 in the next few days.
What is the optimal way to use H100s? I'd guess 8 are not required. I can get 4x H100s to start up like this:
But it fails when used like this, even for a text-only test query with no images. Very short queries seem to work, but an input of around 10k fails.
How can I run the vLLM Docker image against the Meta-provided checkpoint (not the Hugging Face one)?
Same here. If we could know the ETA for adding support for multiple images and interleaved images, it would be truly appreciated. We also have a project that depends on this kind of SOTA VLM. Thank you so much for your hard work and great contributions!
You can add the logic for loading the Meta-provided checkpoint here. The main work will be renaming some of the parameters. You will still need Hugging Face's configs.
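A minimal sketch of the renaming step described above, assuming the Meta checkpoint has already been loaded into a plain dict of tensors. The regex rules in RENAME_RULES are hypothetical placeholders, not the real Meta-to-Hugging Face mapping; the actual names have to be worked out by comparing the keys of the two checkpoints, and the loader still needs the Hugging Face config files.

```python
import re

# Hypothetical rename rules: (regex on a Meta checkpoint key, replacement in
# Hugging Face-style naming). Illustrative only; derive the real mapping by
# comparing the key lists of both checkpoints.
RENAME_RULES = [
    (r"^text_model\.layers\.(\d+)\.attention\.wq\.weight$",
     r"language_model.model.layers.\1.self_attn.q_proj.weight"),
    (r"^text_model\.layers\.(\d+)\.feed_forward\.w1\.weight$",
     r"language_model.model.layers.\1.mlp.gate_proj.weight"),
    (r"^text_model\.tok_embeddings\.weight$",
     r"language_model.model.embed_tokens.weight"),
]


def rename_key(key: str) -> str:
    """Return the key rewritten by the first matching rule, or unchanged."""
    for pattern, replacement in RENAME_RULES:
        new_key, count = re.subn(pattern, replacement, key)
        if count:
            return new_key
    return key


def rename_state_dict(state_dict: dict) -> dict:
    """Rename every key; raise if two source keys map to the same target."""
    renamed = {}
    for key, tensor in state_dict.items():
        new_key = rename_key(key)
        if new_key in renamed:
            raise KeyError(f"duplicate target name: {new_key}")
        renamed[new_key] = tensor
    return renamed


# Usage with a placeholder value standing in for a real tensor.
example = {"text_model.layers.3.attention.wq.weight": "W_q"}
print(rename_state_dict(example))
```

Keys that match no rule pass through unchanged, so unmapped parameters surface later as load errors rather than being silently dropped.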
I got an HF checkpoint thanks to some EU friends and managed to serve the model for half of today, but now it's dying.
What is the general memory requirement for the model? It seems the 90B model cannot run on an 8x 40GB A100 machine.
Pif... imagine I tried to run it on 8 A30s :(
If you run into OOM issues, please reduce
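As a rough back-of-the-envelope check on the memory question above, the sketch below estimates the memory taken by the weights alone. It ignores the KV cache, activations, the profiling pass, and framework overhead, so the real requirement is higher; treating the 90B model as exactly 90e9 parameters is an approximation.

```python
def weight_memory_gib(num_params: float, bytes_per_param: float) -> float:
    """Memory for the model weights alone, in GiB (no KV cache or activations)."""
    return num_params * bytes_per_param / 2**30


def per_gpu_weight_gib(num_params: float, bytes_per_param: float,
                       tensor_parallel_size: int) -> float:
    """Approximate per-GPU share when weights are split evenly across GPUs."""
    return weight_memory_gib(num_params, bytes_per_param) / tensor_parallel_size


# bf16 = 2 bytes per parameter; fp8 = 1 byte per parameter.
print(round(weight_memory_gib(90e9, 2), 1))      # whole model in bf16
print(round(per_gpu_weight_gib(90e9, 2, 8), 1))  # share per GPU at TP=8
print(round(weight_memory_gib(90e9, 1), 1))      # whole model in fp8
```

In bf16 the weights come to about 167.6 GiB in total, or about 21 GiB per GPU at tensor_parallel_size=8, which leaves only around half of each 40 GB A100 for everything else.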
Thank you all for trying out the Llama 3.2 Vision model on vLLM! As you may already know, multimodal Llama 3.2 is quite different from the other LLaVA-style VLMs we currently support on vLLM because it involves cross-attention layers in the language model. Below are the next steps, and their priorities, for features and optimizations related to this new architecture.
P0
P1
P2
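To illustrate the architectural difference described above: in LLaVA-style models the image features are spliced into the token sequence, whereas in a cross-attention layer the text tokens supply the queries and the image features supply the keys and values. The pure-Python sketch below omits the learned Q/K/V/output projections, multiple heads, and gating, so it is a toy illustration of the idea, not vLLM's mllama.py implementation.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attend(queries, keys, values):
    """Scaled dot-product attention: each query attends over all keys/values."""
    d = len(keys[0])
    out = []
    for q in queries:
        weights = softmax([dot(q, k) / math.sqrt(d) for k in keys])
        out.append([sum(w * v[j] for w, v in zip(weights, values))
                    for j in range(len(values[0]))])
    return out


def cross_attention(text_states, image_states):
    """Queries come from text tokens; keys and values come from image
    features. Projections are identity here to keep the sketch minimal."""
    return attend(text_states, image_states, image_states)


text = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 text tokens, dim 2
image = [[1.0, 0.0], [0.0, 1.0]]             # 2 image features, dim 2
print(cross_attention(text, image))
```

Each text token receives a weighted mix of image features, and the output has one row per text token; the image features never enter the token sequence itself.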
Any thoughts on my issue, or should I open a separate issue?
@pseudotensor, the fix is here: #8870.
It's a good first attempt, but as it stands it's basically unusable. Without caching the cross-attention projections it is extremely slow, probably about 1/20th of the speed it should be.
Can you provide more information so we can reproduce your problem, e.g., your benchmark script and your data?
I don't think any of the implementations currently cache the cross-attention projections? For inference, it looks like the outputs of the cross-attention K/V projections over the image can be cached after the first token and reused for every subsequent step without being recalculated (the image doesn't change). See vllm/vllm/model_executor/models/mllama.py, lines 674 to 707 at commit bd429f2.
It's possible I am misunderstanding the code, but it looks like it recomputes them each time even though the output should be the same?
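The idea in this comment can be sketched as a per-request cache: project the image features to keys and values once, then reuse them on every decode step. This is an illustration under stated assumptions (toy_project_kv is a hypothetical stand-in for the learned projections), not how vLLM works internally; as the maintainer's reply points out, vLLM's attention backend already keeps these in its KV cache.

```python
class CrossAttentionKVCache:
    """Compute image K/V projections once per request and reuse them on
    every decode step, since the image features do not change."""

    def __init__(self, project_kv):
        self.project_kv = project_kv  # function: image_states -> (keys, values)
        self._cache = {}              # request_id -> (keys, values)
        self.projection_calls = 0

    def get(self, request_id, image_states):
        if request_id not in self._cache:
            self.projection_calls += 1
            self._cache[request_id] = self.project_kv(image_states)
        return self._cache[request_id]

    def evict(self, request_id):
        """Free the entry when the request finishes."""
        self._cache.pop(request_id, None)


def toy_project_kv(image_states):
    # Hypothetical projections: keys = 2x, values = 3x the image features.
    keys = [[2 * x for x in row] for row in image_states]
    values = [[3 * x for x in row] for row in image_states]
    return keys, values


cache = CrossAttentionKVCache(toy_project_kv)
image = [[1.0, 2.0], [3.0, 4.0]]
for step in range(5):  # five decode steps for one request
    keys, values = cache.get("req-0", image)
print(cache.projection_calls)
```

Over the five decode steps the projection runs only once; evicting the entry at request completion keeps the cache from growing without bound.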
I suggest taking a look at the actual attention module implementation to understand how attention with the KV cache works (lines 15 to 27 at commit bd429f2).
OK, I misunderstood. I think my earlier speed issues may have been a skill issue on my part, so I dug through the other issues to figure out which command-line arguments I was missing. Running:

```bash
CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
  --gpu-memory-utilization 0.95 \
  --model=meta-llama/Llama-3.2-11B-Vision-Instruct \
  --tokenizer=meta-llama/Llama-3.2-11B-Vision-Instruct \
  --download-dir="/home/user/storage/hf_cache" \
  --dtype=bfloat16 \
  --device=cuda \
  --host=127.0.0.1 \
  --port=30000 \
  --max-model-len=8192 \
  --quantization="fp8" \
  --enforce_eager \
  --max_num_seqs=8
```

On 3x 3090s, this gives me:
Found another weird thing with vLLM and Llama 3.2 11B, or possibly another failure on my part to read the docs. If I have
May I ask you to share the vLLM arguments you used to deploy the model?
It's above.
I'm still having other issues too. vLLM will sometimes hang randomly, without errors, and stop responding to requests after a few hours. :(
Worked perfectly for me. Thank you so much!
Unfortunately, performance degrades over time for me and I don't know why. After about an hour of continuous requests, throughput falls to about half the original rate. At least it's very reproducible. Edit: After going through the logs, I wonder if it's another issue with the HTTP server.
On the slow GPU, it stops receiving new requests ahead of time and
OK, I give up. The one server (running on localhost) stops processing requests at random, leaving long stretches of idle time where no work is done and rendering the OpenAI server unusable.
Just to report: I have been serving 90B from the HF checkpoint without issues (except when I tried to use instructor on top of it, which killed the server).
I tried offline inference mode with batch size 8 on a 3090; again it hangs randomly and stops processing batches. On Ctrl+C it shows it is stuck in
vLLM was instantiated like this:

```python
from vllm import LLM

model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
llm = LLM(
    model=model_name,
    max_model_len=8192,  # Adjust as needed
    max_num_seqs=8,  # Adjust as needed
    enforce_eager=True,  # Skip CUDA graph capture
    quantization="fp8",
    gpu_memory_utilization=0.98,
)
```

It was also much slower (~7 s/caption) than the OpenAI server.
I am getting the following error:
Full logs: https://gist.github.com/samos123/43db2724d7ac1bba44a379ce300b156b
I reduced the memory requirements and afterwards got a new error: https://gist.github.com/samos123/ee858936496e1d314785719f9287230a
Update: I only get the issue when I had
@samos123 Yeah, I'm almost sure fp8 for the KV cache isn't working with cross-attention yet. FWIW, the command below works for me on 1x H100:
I did get it to run stably overnight at 5 seconds/caption on 3090s at
The problem I'm seeing is that performance degrades very quickly because vLLM stops batching requests together. I'm not sure how that is determined (based on KV-cache occupancy?). It starts out at roughly double the speed and then degrades within about 20-30 minutes. At startup:
After some time:
Setting
Edit: I tried the same run command on L40s and it reliably runs batches of 8 at around 0.9 seconds/caption, so this seems to be an issue with lower-VRAM cards.
The FastAPI server does indeed have another problem: if you send too many requests at once, they all seem to time out. I was only able to get the L40s to use a consistent batch size of 8 by carefully trickling in requests, which seems less than ideal. Maybe the FastAPI server needs more threads/processes to handle incoming requests efficiently.
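A minimal client-side sketch of the "trickle in requests" workaround: cap the number of outstanding requests with asyncio.Semaphore so the server receives a steady stream rather than a burst. send_request is a hypothetical placeholder that only sleeps; in practice it would POST to the server's /v1/chat/completions endpoint. max_in_flight=8 is an assumption chosen to match --max_num_seqs=8 from the earlier command.

```python
import asyncio


async def send_request(prompt: str) -> str:
    """Hypothetical stand-in for an HTTP call to the OpenAI-compatible server."""
    await asyncio.sleep(0.01)  # simulate network + generation latency
    return f"caption for {prompt}"


async def run_throttled(prompts, max_in_flight=8):
    """Keep at most `max_in_flight` requests outstanding at any time."""
    semaphore = asyncio.Semaphore(max_in_flight)
    in_flight = 0
    peak = 0

    async def worker(prompt):
        nonlocal in_flight, peak
        async with semaphore:  # waits here once max_in_flight are running
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                return await send_request(prompt)
            finally:
                in_flight -= 1

    # gather returns results in the same order as the prompts
    results = await asyncio.gather(*(worker(p) for p in prompts))
    return results, peak


results, peak = asyncio.run(
    run_throttled([f"img{i}" for i in range(20)], max_in_flight=8))
print(len(results), peak)
```

This only shapes the client's traffic; it does not change how the server schedules or times out requests.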
@AmericanPresidentJimmyCarter Yes, the scheduling is based on KV-cache occupancy.
That makes sense. It seems like a lot of the KV cache may stick around too long to be useful: when you first start up the server, it runs almost twice as fast (and handles 3-4 simultaneous requests instead of 1). It would be nice if there were a way to purge the KV cache occasionally, such as an endpoint I could call or a command-line argument.
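A toy model of the KV-cache-occupancy scheduling described above, assuming a block size of 16 tokens and strict first-come-first-served admission: a sequence is only scheduled if enough free KV-cache blocks remain. vLLM's real scheduler also handles preemption, watermarks, and chunked prefill, so this only illustrates why the batch shrinks as free blocks run out.

```python
import math


def blocks_needed(num_tokens: int, block_size: int = 16) -> int:
    """KV-cache blocks a sequence needs to hold `num_tokens` tokens."""
    return math.ceil(num_tokens / block_size)


def admit(waiting, free_blocks, max_num_seqs, block_size=16):
    """Admit waiting sequences (given as token counts) in order while both the
    sequence limit and the free blocks allow.
    Returns (admitted token counts, remaining free blocks)."""
    admitted = []
    for num_tokens in waiting:
        need = blocks_needed(num_tokens, block_size)
        if len(admitted) >= max_num_seqs or need > free_blocks:
            break  # FCFS: stop at the first request that does not fit
        admitted.append(num_tokens)
        free_blocks -= need
    return admitted, free_blocks


queue = [1500] * 8  # eight requests of 1,500 tokens each
print(len(admit(queue, free_blocks=400, max_num_seqs=8)[0]))
print(len(admit(queue, free_blocks=150, max_num_seqs=8)[0]))
```

With 400 free blocks, eight 1,500-token requests (94 blocks each) yield a batch of 4; with only 150 free blocks, the same queue yields a batch of 1.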
For anyone perplexed about this: you might need to tweak
Edit: Go here: #2492
Just FYI, on an 8xH100 node I am getting:
How can I deploy this to AWS? It is for an experiment and I am really new to AI ops. Can anyone assist me, please?
What's the VRAM usage?
~60-70GB
https://docs.vllm.ai/en/stable/models/vlm.html Am I correct that the implication of the above guide is that I can do online inference ONLY with the /v1/chat/completions endpoint and ONLY with a URL for the image? Is there another way to supply an image to the model server besides a web URL?
Btw, I am having no problems serving https://huggingface.co/neuralmagic/Llama-3.2-90B-Vision-Instruct-FP8-dynamic on 2 A100s. Text inference and text + single-image inference both work great but are very slow. Looking forward to the optimized implementation :)
base64
Indeed, here is a guide from OpenAI: https://platform.openai.com/docs/guides/vision
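Following the base64 suggestion above, the sketch below builds an OpenAI-style chat request that embeds the image as a base64 data URL instead of a web URL. It only constructs the JSON body; the three bytes b"\xff\xd8\xff" are a placeholder for a real JPEG file (in practice, read it with open(path, "rb").read()), and the host/port in the comment are placeholders.

```python
import base64
import json


def image_message(image_bytes: bytes, question: str,
                  mime: str = "image/jpeg") -> dict:
    """Build one chat message embedding an image as a base64 data URL."""
    b64 = base64.b64encode(image_bytes).decode("ascii")
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": question},
            {"type": "image_url",
             "image_url": {"url": f"data:{mime};base64,{b64}"}},
        ],
    }


payload = {
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "messages": [image_message(b"\xff\xd8\xff", "What is in this image?")],
    "max_tokens": 64,
}
# POST this body to http://<host>:<port>/v1/chat/completions
body = json.dumps(payload)
print(body)
```

The message shape follows the content-parts format in the linked OpenAI guide, with the data URL taking the place of a web URL.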
I'm running CUDA_VISIBLE_DEVICES=0,1 python3 -m vllm.entrypoints.openai.api_server on 2 NVIDIA GPUs. When I send 1 request it works perfectly, but when I send 2 requests in parallel, only 1 is processed and the other is blocked until the first completes. How can I process requests in parallel? Can anyone help? I'm new to this.
We are working on this in PR #9095.
I don't think that was in reply to javed828's comment: that was in reply to the comment I made about 2 weeks ago at the beginning of the thread.
@tcapelle Yes, I'm sending requests in parallel, and I'm using an NVIDIA T4.
The GPU block count is set very small compared to the VRAM size. Does anyone else have the same issue?
It's pretty easy to hack in support for
Edit: commit here
You can try reducing max_num_seqs in your settings; you'll see the number of GPU blocks increase as you decrease it. I think a lot of memory is reserved for mm_tokens in each sequence.
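The relationship described above can be sketched with a simplified memory model: KV-cache blocks get whatever memory remains after the weights and a profiling pass whose peak activation memory is assumed to grow linearly with max_num_seqs. The per-sequence activation cost and the layer/head counts below are hypothetical illustrative values, not numbers read from vLLM or the model's config.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    """Bytes of K and V stored per token across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def num_gpu_blocks(total_gib, utilization, weights_gib, activation_gib_per_seq,
                   max_num_seqs, bytes_per_token, block_size=16):
    """Simplified model: memory left after the weights and a profiling pass
    (assumed to grow linearly with max_num_seqs) becomes KV-cache blocks."""
    budget_gib = (total_gib * utilization - weights_gib
                  - activation_gib_per_seq * max_num_seqs)
    if budget_gib <= 0:
        return 0
    return int(budget_gib * 2**30 // (bytes_per_token * block_size))


# Hypothetical values: 40 layers, 8 KV heads, head_dim 128, bf16 KV cache,
# a 24 GiB card, 11 GiB of weights, 0.5 GiB of activations per sequence.
per_token = kv_bytes_per_token(40, 8, 128)
for seqs in (4, 16):
    print(seqs, num_gpu_blocks(24, 0.95, 11, 0.5, seqs, per_token))
```

Under these assumptions, lowering max_num_seqs shrinks the reserved activation budget and the freed memory shows up as additional KV-cache blocks.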
How can we optimize Llama-3.2-11B to run on 4 T4 GPUs? The average throughput is only 2 tokens per second.
Running the server (using the vLLM CLI or our Docker image):

```bash
# 11B Vision
vllm serve meta-llama/Llama-3.2-11B-Vision-Instruct --enforce-eager --max-num-seqs 16
# 90B Vision, sharded across 8 GPUs with tensor parallelism
vllm serve meta-llama/Llama-3.2-90B-Vision-Instruct --enforce-eager --max-num-seqs 32 --tensor-parallel-size 8
```
Currently:
Please see the next steps for better supporting this model on vLLM.
cc @heheda12345 @ywang96