Description
Proposed Optimizations
i) Modify the load_tensors method in clip.cpp to allocate the vision model weights on the CPU and stream them to the device only at runtime. This saves device memory equal to the total size of the vision tensors, without making image encoding nearly as slow as --no-mmproj-offload, which keeps the computation on the CPU as well (see the sketch below).
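A rough illustration of the loading side of this idea, as a minimal sketch using the public ggml-backend API; this is not the actual clip.cpp patch, the helper names and surrounding structure are simplified, and header locations may differ between ggml versions:

```cpp
// Sketch only: host allocation of vision weights plus on-demand streaming of a
// weight tensor to the device. Not the actual clip.cpp change.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// Load time: place every vision tensor created in ctx_data in a single host
// buffer instead of a device buffer, so the weights occupy no VRAM while idle.
static ggml_backend_buffer_t alloc_vision_weights_on_host(struct ggml_context * ctx_data) {
    ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type();
    return ggml_backend_alloc_ctx_tensors_from_buft(ctx_data, cpu_buft);
}

// Encode time: stream a host-resident weight into its device-side counterpart
// (a tensor of identical shape allocated in a device buffer) right before the
// graph that uses it is computed.
static void stream_weight_to_device(struct ggml_tensor * host_w, struct ggml_tensor * device_w) {
    ggml_backend_tensor_copy(host_w, device_w);
}
```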
ii) During multi-modal inference, defer the LLM model and context initialization out of mtmd-cli and instead perform it just-in-time from mtmd-helper, after the IMAGE chunk has been (pre-)encoded and the CLIP model's memory resources have been freed. This preserves the downstream decode order but releases the vision VRAM earlier.
Thus the peak VRAM usage drops from SUM(clip encode, text decode) to MAX(clip encode, text decode): whichever of the two phases needs more memory sets the peak, instead of their sum (see the control-flow sketch below).
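The reordering can be illustrated with the following self-contained sketch; every type and helper function in it is a hypothetical stub (not the real llama.cpp / mtmd API), it only shows the intended ordering:

```cpp
// Hypothetical control-flow sketch with stub types/functions.
#include <cstdio>

struct VisionEncoder { /* CLIP weights + compute buffers */ };
struct ImageEmbeds   { /* pre-encoded image chunk        */ };
struct LlmModel      { /* text model weights             */ };
struct LlmContext    { /* KV cache + compute buffers     */ };

static VisionEncoder load_vision_encoder(const char *)                 { return {}; }
static ImageEmbeds   encode_image(VisionEncoder &, const char *)       { return {}; }
static void          free_vision(VisionEncoder &)                      { }
static LlmModel      load_text_model(const char *)                     { return {}; }
static LlmContext    create_context(LlmModel &, int)                   { return {}; }
static void          decode(LlmContext &, ImageEmbeds &, const char *) { }

int main() {
    // 1. Vision side only: load CLIP and pre-encode the image chunk.
    VisionEncoder clip = load_vision_encoder("mmproj.gguf");
    ImageEmbeds   img  = encode_image(clip, "image.png");      // VRAM peak: clip encode

    // 2. Free the CLIP weights/compute buffers before the LLM exists.
    free_vision(clip);

    // 3. Just-in-time: initialize the LLM model and context only now.
    LlmModel   model = load_text_model("model.gguf");
    LlmContext lctx  = create_context(model, /*n_ctx=*/12000); // VRAM peak: text decode

    // 4. Decode as before; chunk order is preserved downstream.
    decode(lctx, img, "<PROMPT>");
    std::printf("done\n");
    return 0;
}
```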
iii) Currently, running VQA inference on very large images with Flash Attention and GPU offload enabled, e.g. with the command line llama-mtmd-cli.exe -m Qwen2.5-VL-7B-Instruct-q8_0.gguf --mmproj Qwen2.5-VL-7B-Instruct-mmproj-bf16.gguf -p <PROMPT> --image <4K_IMAGE> -ngl 99 -c 12000 --image-max-tokens 8192 -n 100, fails with a runtime assertion: cpy.cu:359: GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX) failed. In other words, we start exceeding the 2 GB (INT_MAX) ggml_cpy limit.
The proposed solution is to implement tiled flash attention, so that ggml_cuda_cpy never sees a tensor larger than 2 GB / INT_MAX for inputs with large FA masks: tile the Q tensor into smaller chunks, process each chunk with flash attention, concatenate the results, and apply the output projection, as sketched below.
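A rough sketch of what the tiling could look like at the ggml graph level; the exact ggml_flash_attn_ext signature, mask padding requirements, and any ggml_cont calls needed depend on the ggml version and backend, and the tile_size parameter here is illustrative only:

```cpp
// Sketch of tiled flash attention at the ggml graph level, assuming
// q = [d_head, n_q, n_head], k/v = [d_head, n_kv, n_head], mask = [n_kv, n_q].
// Illustration of the approach, not the exact implementation.
#include <algorithm>
#include <cstdint>
#include "ggml.h"

static struct ggml_tensor * flash_attn_tiled(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * kq_mask,
        float                 kq_scale,
        int64_t               tile_size) {
    const int64_t d_head = q->ne[0];
    const int64_t n_q    = q->ne[1];
    const int64_t n_head = q->ne[2];

    struct ggml_tensor * out = nullptr;

    for (int64_t i0 = 0; i0 < n_q; i0 += tile_size) {
        const int64_t n_cur = std::min(tile_size, n_q - i0);

        // View of the current Q tile: same strides, offset by i0 query positions.
        struct ggml_tensor * q_tile = ggml_view_3d(ctx, q,
                d_head, n_cur, n_head, q->nb[1], q->nb[2], i0*q->nb[1]);

        // Matching slice of the mask rows for these query positions.
        struct ggml_tensor * mask_tile = ggml_view_2d(ctx, kq_mask,
                kq_mask->ne[0], n_cur, kq_mask->nb[1], i0*kq_mask->nb[1]);

        // Flash attention on the tile against the full K/V.
        struct ggml_tensor * o_tile = ggml_flash_attn_ext(ctx, q_tile, k, v, mask_tile,
                kq_scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f);

        // Flatten to [d_head*n_head, n_cur] and append along the token dimension.
        o_tile = ggml_reshape_2d(ctx, o_tile, d_head*n_head, n_cur);
        out    = out ? ggml_concat(ctx, out, o_tile, 1) : o_tile;
    }

    // Result covers all n_q positions; the output projection is applied afterwards.
    return out;
}
```

With a suitably small tile_size, no single intermediate tensor handled by ggml_cuda_cpy should exceed the INT_MAX limit.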
My Implementation
I have implemented each of these proposed code changes and you can find my changes here: https://github.com/deepshnv/llama.cpp/tree/deepshnv/mtmd_vram_opti
@ggerganov @slaren, I would appreciate it if you could go through these implementations. I will leave it to you and the other llama.cpp devs to decide if and how much of my code to use.
The draft PR can be found here: #17802
Results showing reduced VRAM usage while maintaining performance
All of these numbers were obtained on the same PC with an RTX 5080 GPU.
Command-line used: llama-mtmd-cli.exe -m Qwen2.5-VL-7B-Instruct-q8_0.gguf --mmproj Qwen2.5-VL-7B-Instruct-mmproj-bf16.gguf -p <PROMPT> --image <IMAGE> // rest of the cmdline flags
AI usage: The implementation is mostly my own code, with minimal help from AI (Cursor) limited to code structuring, navigation, and commenting. Even then, I did my due diligence and manually reviewed every change before accepting it.