diff --git a/common/arg.cpp b/common/arg.cpp
index 9f3c8a9754..3b40edfde2 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1803,6 +1803,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.mmproj_use_gpu = false;
         }
     ).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ_OFFLOAD"));
+    add_opt(common_arg({ "--clip-reduced-vram" },
+        "offload clip weights to CPU and stream them to the device at runtime (default: false)",
+        [](common_params & params) {
+            params.clip_reduced_vram = true;
+        }
+    ).set_env("LLAMA_ARG_CLIP_REDUCED_VRAM"));
    add_opt(common_arg(
        {"--image", "--audio"}, "FILE",
        "path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n",
diff --git a/common/common.cpp b/common/common.cpp
index 0d7fd9a937..8b8a90367e 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1184,6 +1184,181 @@ struct common_init_result common_init_from_params(common_params & params) {
     return iparams;
 }
 
+// Initialize llama_context and related state using an already-loaded llama_model
+// - Does not take ownership of the model passed in; the caller manages its lifetime
+// - Mirrors common_init_from_params except for model creation/free
+struct common_init_result common_init_from_existing_model(common_params & params, struct llama_model * model_in) {
+    common_init_result iparams;
+
+    if (model_in == NULL) {
+        LOG_ERR("%s: model '%s' is NULL\n", __func__, params.model.path.c_str());
+        return iparams;
+    }
+
+    common_init_sampler_from_model(model_in, params.sampling);
+
+    const llama_vocab * vocab = llama_model_get_vocab(model_in);
+
+    auto cparams = common_context_params_to_llama(params);
+
+    llama_context * lctx = llama_init_from_model(model_in, cparams);
+    if (lctx == NULL) {
+        LOG_ERR("%s: failed to create context with existing model '%s'\n", __func__, params.model.path.c_str());
+        return iparams;
+    }
+
+    if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
+        LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
+        params.ctx_shift = false;
+    }
+
+    if (!params.control_vectors.empty()) {
+        if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
+        if (params.control_vector_layer_end   <= 0) params.control_vector_layer_end   = llama_model_n_layer(model_in);
+
+        const auto cvec = common_control_vector_load(params.control_vectors);
+        if (cvec.n_embd == -1) {
+            llama_free(lctx);
+
+            return iparams;
+        }
+
+        int err = llama_apply_adapter_cvec(
+                lctx,
+                cvec.data.data(),
+                cvec.data.size(),
+                cvec.n_embd,
+                params.control_vector_layer_start,
+                params.control_vector_layer_end);
+        if (err) {
+            llama_free(lctx);
+
+            return iparams;
+        }
+    }
+
+    if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) {
+        bool ok = true;
+
+        if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
+        bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
+        bool has_rerank_prompt = llama_model_chat_template(model_in, "rerank") != NULL;
+
+        if (!has_eos && !has_sep && !has_rerank_prompt) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
+            ok = false;
+        } else if (!has_eos) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
+        }
+
+        if (!ok) {
+            llama_free(lctx);
+
+            return iparams;
+        }
+    }
+
+    // load and optionally apply lora adapters
+    for (auto & la : params.lora_adapters) {
+        llama_adapter_lora_ptr lora;
+        lora.reset(llama_adapter_lora_init(model_in, la.path.c_str()));
+        if (lora == nullptr) {
+            LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
+            llama_free(lctx);
+            return iparams;
+        }
+
+        char buf[1024];
+        la.ptr = lora.get();
+        llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
+        la.task_name = buf;
+        llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
+        la.prompt_prefix = buf;
+        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+    }
+
+    if (!params.lora_init_without_apply) {
+        common_set_adapter_lora(lctx, params.lora_adapters);
+    }
+
+    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sampling.ignore_eos = false;
+    }
+
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
+        }
+    }
+
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
+    if (params.sampling.penalty_last_n == -1) {
+        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    }
+
+    if (params.sampling.dry_penalty_last_n == -1) {
+        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    }
+
+    if (params.warmup) {
+        LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
+
+        llama_set_warmup(lctx, true);
+
+        std::vector<llama_token> tmp;
+        llama_token bos = llama_vocab_bos(vocab);
+        llama_token eos = llama_vocab_eos(vocab);
+
+        // some models (e.g. T5) don't have a BOS token
+        if (bos != LLAMA_TOKEN_NULL) {
+            tmp.push_back(bos);
+        }
+        if (eos != LLAMA_TOKEN_NULL) {
+            tmp.push_back(eos);
+        }
+        if (tmp.empty()) {
+            tmp.push_back(0);
+        }
+
+        if (llama_model_has_encoder(model_in)) {
+            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
+            llama_token decoder_start_token_id = llama_model_decoder_start_token(model_in);
+            if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
+                decoder_start_token_id = bos;
+            }
+            tmp.clear();
+            tmp.push_back(decoder_start_token_id);
+        }
+        if (llama_model_has_decoder(model_in)) {
+            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
+        }
+        llama_memory_clear(llama_get_memory(lctx), true);
+        llama_synchronize(lctx);
+        llama_perf_context_reset(lctx);
+        llama_set_warmup(lctx, false);
+    }
+
+    iparams.context.reset(lctx);
+
+    return iparams;
+}
+
 std::string get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
diff --git a/common/common.h b/common/common.h
index 2f23d0baa8..460492b1f5 100644
--- a/common/common.h
+++ b/common/common.h
@@ -537,6 +537,8 @@ struct common_params {
     bool has_speculative() const {
         return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
     }
+
+    bool clip_reduced_vram = false; // keep vision model weights offloaded to CPU and stream them at runtime to the required device
 };
 
 // call once at the start of a program if it uses libcommon
@@ -657,6 +659,8 @@ struct common_init_result {
 };
 
 struct common_init_result common_init_from_params(common_params & params);
+// Initialize only the context using an already-loaded llama_model*
+struct common_init_result common_init_from_existing_model(common_params & params, struct llama_model * model);
 
 struct llama_model_params   common_model_params_to_llama  (      common_params & params);
 struct llama_context_params common_context_params_to_llama(const common_params & params);
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index cd47865bf4..eaa8f322cb 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -490,3 +490,4 @@ static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
 
 //
 projector_type clip_get_projector_type(const struct clip_ctx * ctx);
+int64_t clip_get_image_encode_timing(const struct clip_ctx * ctx);
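Note: for reference, a minimal usage sketch of the new split init path (the call sequence follows the functions declared above; error handling and the surrounding application code are elided):

    // load the model once; the caller keeps ownership of it
    llama_model_params mparams = common_model_params_to_llama(params);
    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);

    // build a context (plus samplers/adapters) on top of the existing model
    common_init_result init = common_init_from_existing_model(params, model);
    llama_context * lctx = init.context.get();

    // ... run inference ...

    init.context.reset();    // free the context first
    llama_model_free(model); // then the model, since common_init_from_existing_model does not own it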
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index d8222d8814..421404da1c 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -424,11 +424,15 @@ struct clip_ctx {
     ggml_backend_t backend     = nullptr;
     ggml_backend_t backend_cpu = nullptr;
     ggml_backend_buffer_ptr buf;
+    std::vector<ggml_backend_buffer_ptr> additional_buffers; // for tensor overrides
 
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
     clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
 
+    bool clip_reduced_vram = false;
+    int64_t image_encode_timing = 0;
+
     // for debugging
     bool debug_graph = false;
     std::vector<ggml_tensor *> debug_print_tensors;
@@ -476,6 +480,8 @@ struct clip_ctx {
         sched.reset(
             ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false, true)
         );
+
+        clip_reduced_vram = ctx_params.clip_reduced_vram;
     }
 
     ~clip_ctx() {
@@ -762,6 +768,24 @@ struct clip_graph {
            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
        }
 
+        // Check whether tiled flash attention is needed to avoid the 2GB/INT_MAX ggml_cpy limit in cpy.cu: GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+        bool needs_tiled_fa = false;
+        size_t tiled_fa_q_tile = n_pos;
+        if (use_window_attn && ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+            // Calculate the largest q_tile that keeps the mask under 2GB during the cast/pad operations
+            // assume the worst-case mask size = n_kv * qlen * element_bytes (worst case F32 = 4 bytes)
+            const int64_t n_kv = n_pos;
+            const size_t mask_element_bytes = sizeof(float);
+            const int64_t max_qlen = (INT_MAX / (n_kv * (int64_t)mask_element_bytes)) & ~63LL;
+            if ((int64_t)tiled_fa_q_tile > max_qlen && max_qlen >= 64) {
+                tiled_fa_q_tile = (size_t)max_qlen;
+                needs_tiled_fa = true;
+                LOG_INF("%s: will use tiled FA with q_tile=%zu to keep mask under 2GB/INT_MAX (n_pos=%d)\n", __func__, tiled_fa_q_tile, n_pos);
+            }
+        }
+
+        ggml_tensor * window_mask_raw = nullptr; // unprepared mask for the tiled FA path
+
        if (use_window_attn) {
            // handle window attention inputs
            inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
@@ -772,13 +796,19 @@ struct clip_graph {
            ggml_set_name(window_mask, "window_mask");
            ggml_set_input(window_mask);
 
-            // if flash attn is used, we need to pad the mask and cast to f16
+            // if tiling is needed, keep the raw mask and handle padding/casting per tile; otherwise pad and cast the full mask
            if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
-                int n_pad = GGML_PAD(window_mask->ne[1], GGML_KQ_MASK_PAD) - window_mask->ne[1];
-                if (n_pad > 0) {
-                    window_mask = ggml_pad(ctx0, window_mask, 0, n_pad, 0, 0);
+                if (needs_tiled_fa) {
+                    // Keep the raw mask for per-tile processing
+                    window_mask_raw = window_mask;
+                }
+                else {
+                    int n_pad = GGML_PAD(window_mask->ne[1], GGML_KQ_MASK_PAD) - window_mask->ne[1];
+                    if (n_pad > 0) {
+                        window_mask = ggml_pad(ctx0, window_mask, 0, n_pad, 0, 0);
+                    }
+                    window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
                }
-                window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
            }
 
            // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
@@ -829,8 +859,17 @@ struct clip_graph {
 
                ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;
 
-                cur = build_attn(layer.o_w, layer.o_b,
-                    Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
+                // Use tiled flash attention when needed (decided before the layer loop)
+                const bool use_tiled_fa = needs_tiled_fa && !full_attn;
+
+                if (use_tiled_fa) {
+                    cur = build_flash_attn_tiled(layer.o_w, layer.o_b,
+                        Qcur, Kcur, Vcur, window_mask_raw, kq_scale, tiled_fa_q_tile, il);
+                }
+                else {
+                    cur = build_attn(layer.o_w, layer.o_b,
+                        Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
+                }
                cb(cur, "attn_out", il);
            }
 
@@ -2370,6 +2409,82 @@ struct clip_graph {
        return cur;
    }
 
+    // Tiled flash attention to avoid the 2GB/INT_MAX limit in ggml_cuda_cpy: GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+    // Tiles the Q tensor into smaller chunks, processes each with flash attention, then concatenates the results and applies the output projection before returning cur
+    ggml_tensor * build_flash_attn_tiled(ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur,
+            ggml_tensor * kq_mask_raw, float kq_scale, size_t q_tile_size, int il) const
+    {
+        const int64_t d_head        = q_cur->ne[0];
+        const int64_t n_head        = q_cur->ne[1];
+        const int64_t num_positions = q_cur->ne[2];
+        const int64_t batch_size    = 1;
+
+        ggml_tensor * k_fa = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+        ggml_tensor * v_fa = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+        if (k_fa->type == GGML_TYPE_F32) {
+            k_fa = ggml_cast(ctx0, k_fa, GGML_TYPE_F16);
+        }
+        if (v_fa->type == GGML_TYPE_F32) {
+            v_fa = ggml_cast(ctx0, v_fa, GGML_TYPE_F16);
+        }
+
+        std::vector<ggml_tensor *> tile_results;
+        for (int64_t q0 = 0; q0 < num_positions; q0 += (int64_t)q_tile_size) {
+            const int64_t qlen = std::min((int64_t)q_tile_size, num_positions - q0);
+
+            // Create a Q tile view matching the original layout, then permute for FA
+            // q_cur shape: [d_head, n_head, n_positions] -> view as [d_head, n_head, qlen, 1]
+            const size_t q_off = (size_t)q0 * q_cur->nb[2];
+            ggml_tensor * q_tile = ggml_view_4d(ctx0, q_cur,
+                d_head, n_head, qlen, batch_size,
+                q_cur->nb[1], q_cur->nb[2], 0,
+                q_cur->view_offs + q_off);
+            // Permute to the FA layout: [d_head, n_head, qlen, B] -> [d_head, qlen, n_head, B]
+            q_tile = ggml_permute(ctx0, q_tile, 0, 2, 1, 3);
+
+            ggml_tensor * mask_tile = nullptr;
+            if (kq_mask_raw != nullptr) {
+                const size_t m_off = (size_t)q0 * kq_mask_raw->nb[1];
+                mask_tile = ggml_view_2d(ctx0, kq_mask_raw,
+                    kq_mask_raw->ne[0], qlen,
+                    kq_mask_raw->nb[1],
+                    kq_mask_raw->view_offs + m_off);
+
+                // Pad and cast the mask tile for flash attention
+                const int64_t padded_qlen = GGML_PAD(qlen, GGML_KQ_MASK_PAD);
+                if (qlen < padded_qlen) {
+                    mask_tile = ggml_pad(ctx0, mask_tile, 0, (int)(padded_qlen - qlen), 0, 0);
+                }
+                mask_tile = ggml_cast(ctx0, mask_tile, GGML_TYPE_F16);
+            }
+
+            ggml_tensor * tile_out = ggml_flash_attn_ext(ctx0, q_tile, k_fa, v_fa, mask_tile, kq_scale, 0.0f, 0.0f);
+            ggml_flash_attn_ext_set_prec(tile_out, GGML_PREC_F32);
+
+            // Reshape to 2D for concatenation: [d_head * n_head, qlen]
+            tile_out = ggml_reshape_2d(ctx0, tile_out, tile_out->ne[0] * tile_out->ne[1], tile_out->ne[2]);
+            tile_results.push_back(tile_out);
+        }
+
+        // Concatenate all tiles along the sequence dimension
+        ggml_tensor * cur = tile_results[0];
+        for (size_t i = 1; i < tile_results.size(); i++) {
+            cur = ggml_concat(ctx0, cur, tile_results[i], 1);
+        }
+        cur = ggml_cont(ctx0, cur);
+
+        cb(cur, "kqv_out_tiled", il);
+
+        // Apply the output projection
+        if (wo) {
+            cur = ggml_mul_mat(ctx0, wo, cur);
+        }
+        if (wo_b) {
+            cur = ggml_add(ctx0, cur, wo_b);
+        }
+        return cur;
+    }
+
    // implementation of the 2D RoPE without adding a new op in ggml
    // this is not efficient (use double the memory), but works on all backends
    // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
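Note: a rough standalone check of the tile-size arithmetic used above. The n_pos value is hypothetical (large enough to trip the limit); the result is rounded down to a multiple of 64, presumably to match the FA mask padding granularity (GGML_KQ_MASK_PAD):

    #include <climits>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t n_pos = 40000;                   // hypothetical: ~40k patch positions in one window-attention pass
        const int64_t n_kv  = n_pos;
        const int64_t mask_element_bytes = 4;          // worst case: F32 mask
        const int64_t full_mask_bytes = n_kv * n_pos * mask_element_bytes;            // ~6.4 GB, far above INT_MAX
        const int64_t max_qlen = (INT_MAX / (n_kv * mask_element_bytes)) & ~63LL;     // largest q tile, multiple of 64
        const int64_t tile_mask_bytes = n_kv * max_qlen * mask_element_bytes;         // stays below INT_MAX
        printf("full mask: %lld bytes, max q tile: %lld rows, per-tile mask: %lld bytes\n",
               (long long) full_mask_bytes, (long long) max_qlen, (long long) tile_mask_bytes);
        return 0;
    }

For n_pos = 40000 this yields a q tile of 13376 rows and a per-tile mask of about 2.14e9 bytes, just under INT_MAX.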
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp (continued)
@@ -3263,8 +3378,67 @@ struct clip_model_loader {
 
        // alloc memory and offload data
        ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
-        ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
-        ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+        ggml_backend_buffer_type_t cpu_buft = ggml_backend_get_default_buffer_type(ctx_clip.backend_cpu);
+
+        if (ctx_clip.clip_reduced_vram) {
+            LOG_INF("%s: clip_reduced_vram enabled, offloading all clip tensors to CPU\n", __func__);
+            std::map<ggml_backend_buffer_type_t, std::vector<ggml_tensor *>> tensors_by_buft;
+            // Determine the buffer type for each tensor
+            for (auto & t : tensors_to_load) {
+                struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
+                tensors_by_buft[cpu_buft].push_back(cur);
+            }
+
+            // For each buffer type, create a context and allocate tensors
+            bool first_buffer = true;
+            for (auto & [buft, tensor_list] : tensors_by_buft)
+            {
+                // Create a temporary context for this buffer type
+                struct ggml_init_params temp_params = {
+                    /*.mem_size   =*/ tensor_list.size() * ggml_tensor_overhead(),
+                    /*.mem_buffer =*/ NULL,
+                    /*.no_alloc   =*/ true,
+                };
+                ggml_context_ptr temp_ctx(ggml_init(temp_params));
+                if (!temp_ctx) {
+                    throw std::runtime_error(string_format("%s: failed to create temporary context\n", __func__));
+                }
+
+                // Create tensor references in the temporary context
+                for (auto * orig_tensor : tensor_list) {
+                    ggml_tensor * temp_tensor = ggml_dup_tensor(temp_ctx.get(), orig_tensor);
+                    ggml_set_name(temp_tensor, orig_tensor->name);
+                }
+
+                // Allocate a buffer for this context
+                auto buffer = ggml_backend_buffer_ptr(ggml_backend_alloc_ctx_tensors_from_buft(temp_ctx.get(), buft));
+                if (!buffer) {
+                    throw std::runtime_error(string_format("%s: failed to allocate buffer for %s\n", __func__, ggml_backend_buft_name(buft)));
+                }
+                ggml_backend_buffer_set_usage(buffer.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+
+                // Point the original tensors at the newly allocated storage
+                for (auto * orig_tensor : tensor_list) {
+                    ggml_tensor * temp_tensor = ggml_get_tensor(temp_ctx.get(), orig_tensor->name);
+                    orig_tensor->buffer = temp_tensor->buffer;
+                    orig_tensor->data   = temp_tensor->data;
+                }
+
+                // transfer ownership of the buffer to ctx_clip so it lives as long as the model
+                if (first_buffer) {
+                    ctx_clip.buf = std::move(buffer);
+                    first_buffer = false;
+                } else {
+                    ctx_clip.additional_buffers.push_back(std::move(buffer));
+                }
+            }
+        }
+        else {
+            // Baseline: no overrides are needed
+            ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
+            ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+        }
+
        for (auto & t : tensors_to_load) {
            ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
            const size_t offset = tensor_offset[t->name];
@@ -4617,6 +4791,7 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
 }
 
 bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
+    int64_t t0 = ggml_time_ms();
    const clip_image_f32_batch & imgs = *imgs_c_ptr;
    int batch_size = imgs.entries.size();
 
@@ -4628,9 +4803,23 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 
    // build the inference graph
    ctx->debug_print_tensors.clear();
-    ggml_backend_sched_reset(ctx->sched.get());
    ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
-    ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
+
+    // if clip_reduced_vram is enabled: use a temporary scheduler so the compute VRAM can be freed right after the encode
+    ggml_backend_sched_t sched_to_use = nullptr;
+    ggml_backend_sched_ptr sched_local;
+    if (ctx->clip_reduced_vram) {
+        sched_local.reset(ggml_backend_sched_new(ctx->backend_ptrs.data(), ctx->backend_buft.data(), (int) ctx->backend_ptrs.size(), ctx->max_nodes,
+            /*parallel*/ false, /*op_offload*/ true));
+        ggml_backend_sched_reset(sched_local.get());
+        sched_to_use = sched_local.get();
+    }
+    else {
+        ggml_backend_sched_reset(ctx->sched.get());
+        sched_to_use = ctx->sched.get();
+    }
+
+    ggml_backend_sched_alloc_graph(sched_to_use, gf);
 
    // set inputs
    const auto & model = ctx->model;
@@ -4965,7 +5154,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
        }
    }
 
-    auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
+    auto status = ggml_backend_sched_graph_compute(sched_to_use, gf);
    if (status != GGML_STATUS_SUCCESS) {
        LOG_ERR("%s: ggml_backend_sched_graph_compute failed with error %d\n", __func__, status);
        return false;
@@ -4997,6 +5186,17 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
    // copy the embeddings to the location passed by the user
    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
 
+    // if a temporary scheduler was used, freeing it here releases its compute buffers/VRAM
+    if (ctx->clip_reduced_vram) {
+        // ensure all device work is finished before destroying the temporary scheduler
+        ggml_backend_sched_synchronize(sched_to_use);
+        // explicit reset to free the underlying resources now
+        sched_local.reset();
+        // synchronize the CPU backend as well, for completeness
+        ggml_backend_synchronize(ctx->backend_cpu);
+    }
+    ctx->image_encode_timing = ggml_time_ms() - t0;
+
    return true;
 }
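Note: condensed, the temporary-scheduler path above has the following lifecycle (a sketch reusing the same ggml scheduler calls as the patch; input upload and error handling are elided, and the op_offload comment reflects my reading of ggml_backend_sched_new rather than anything stated in this diff):

    ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);

    // per-encode scheduler: its compute buffers exist only for the duration of this call
    ggml_backend_sched_ptr sched(ggml_backend_sched_new(
        ctx->backend_ptrs.data(), ctx->backend_buft.data(),
        (int) ctx->backend_ptrs.size(), ctx->max_nodes,
        /*parallel*/ false, /*op_offload*/ true)); // op_offload lets ops whose weights sit in host memory run on the device

    ggml_backend_sched_reset(sched.get());
    ggml_backend_sched_alloc_graph(sched.get(), gf);   // reserves compute buffers (VRAM) for this graph only
    // ... set graph inputs here ...
    ggml_backend_sched_graph_compute(sched.get(), gf);
    ggml_backend_sched_synchronize(sched.get());

    sched.reset(); // destroying the scheduler releases its compute buffers immediately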
@@ -5116,3 +5316,7 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
    batch->entries.push_back(clip_image_f32_ptr(audio));
    batch->is_audio = true;
 }
+
+int64_t clip_get_image_encode_timing(const struct clip_ctx * ctx) {
+    return ctx->image_encode_timing;
+}
diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h
index c1442afe6b..65307b63cd 100644
--- a/tools/mtmd/clip.h
+++ b/tools/mtmd/clip.h
@@ -34,6 +34,7 @@ struct clip_context_params {
    enum clip_flash_attn_type flash_attn_type;
    int image_min_tokens;
    int image_max_tokens;
+    bool clip_reduced_vram; // offload clip weights to CPU and stream them to the backend device at runtime
 };
 
 struct clip_init_result {
diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp
index 6679de309b..c8bd47ed07 100644
--- a/tools/mtmd/mtmd-cli.cpp
+++ b/tools/mtmd/mtmd-cli.cpp
@@ -87,15 +87,114 @@ struct mtmd_cli_context {
    int n_threads    = 1;
    llama_pos n_past = 0;
 
+    common_params saved_params;       // keep a copy for the JIT LLM init
+    int32_t saved_n_gpu_layers = 0;
+
+    mtmd_cli_context(common_params & params) {
+        if (params.clip_reduced_vram)
+        {
+            saved_params       = params;
+            saved_n_gpu_layers = params.n_gpu_layers; // save the original number of GPU layers
+            n_threads          = params.cpuparams.n_threads;
+            n_batch            = params.n_batch;
+
+            // Load the LLM model with n_gpu_layers=0 (CPU only) for now; it will be moved to n_gpu_layers JIT later
+            params.n_gpu_layers = 0;
+            LOG_INF("%s: clip_reduced_vram enabled - loading LLM model with n_gpu_layers=0\n", __func__);
+
+            // Defer the LLM context init; the model is still needed to build the chat templates
+            auto mparams = common_model_params_to_llama(params);
+            model = llama_model_load_from_file(params.model.path.c_str(), mparams);
+            if (!model) {
+                exit(1);
+            }
+            tmpls = common_chat_templates_init(model, params.chat_template);
+            use_jinja = params.use_jinja;
+            chat_history.clear();
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str());
+
+            init_vision_context(params);
+            mtmd_set_llm_init_callback(
+                ctx_vision.get(),
+                [](void * user_data) {
+                    auto * self = static_cast<mtmd_cli_context *>(user_data);
+                    self->init_llm_context(self->saved_params);
+                    // pass the fresh lctx back into mtmd for the helper functions
+                    mtmd_set_llm_context(self->ctx_vision.get(), self->lctx);
+                },
+                this);
+        }
+        else
+        {
+            // Baseline
+            llama_init = common_init_from_params(params);
+            model = llama_init.model.get();
+            lctx = llama_init.context.get();
+            vocab = llama_model_get_vocab(model);
+            smpl = common_sampler_init(model, params.sampling);
+            n_threads = params.cpuparams.n_threads;
+            batch = llama_batch_init(1, 0, 1); // batch for next token generation
+            n_batch = params.n_batch;
+
+            if (!model || !lctx) {
+                exit(1);
+            }
+
+            if (!llama_model_chat_template(model, nullptr) && params.chat_template.empty()) {
+                LOG_ERR("Model does not have chat template.\n");
+                LOG_ERR("  For old llava models, you may need to use '--chat-template vicuna'\n");
+                LOG_ERR("  For MobileVLM models, use '--chat-template deepseek'\n");
+                LOG_ERR("  For Mistral Small 3.1, use '--chat-template mistral-v7'\n");
+                exit(1);
+            }
+
+            tmpls = common_chat_templates_init(model, params.chat_template);
+            use_jinja = params.use_jinja;
+            chat_history.clear();
+            LOG_INF("%s: chat template example:\n%s\n", __func__,
+                common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str());
+
+            init_vision_context(params);
+
+            // load antiprompt tokens for legacy templates
+            if (params.chat_template == "vicuna") {
+                antiprompt_tokens = common_tokenize(lctx, "ASSISTANT:", false, true);
+            } else if (params.chat_template == "deepseek") {
+                antiprompt_tokens = common_tokenize(lctx, "###", false, true);
+            }
+        }
+    }
+
+    void init_llm_context(common_params & params)
+    {
+        // Free the CPU-loaded model
+        if (model) {
+            llama_model_free(model);
+            model = nullptr;
+        }
+
+        // Reload the model with the original n_gpu_layers (GPU offloading)
+        params.n_gpu_layers = saved_n_gpu_layers;
+        LOG_INF("%s: reloading LLM model JIT with n_gpu_layers=%d\n", __func__, saved_n_gpu_layers);
-    mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
-        model = llama_init.model.get();
-        lctx = llama_init.context.get();
-        vocab = llama_model_get_vocab(model);
-        smpl = common_sampler_init(model, params.sampling);
-        n_threads = params.cpuparams.n_threads;
-        batch = llama_batch_init(1, 0, 1); // batch for next token generation
-        n_batch = params.n_batch;
+        auto mparams = common_model_params_to_llama(params);
+        model = llama_model_load_from_file(params.model.path.c_str(), mparams);
+        if (!model) {
+            LOG_ERR("Failed to reload LLM model JIT\n");
+            exit(1);
+        }
+
+        // Create the context from the GPU-loaded model
+        auto init2 = common_init_from_existing_model(params, model);
+        llama_init.model.reset(model);
+        llama_init.context = std::move(init2.context);
+
+        // refresh raw pointers
+        model = llama_init.model.get();
+        lctx  = llama_init.context.get();
+        vocab = llama_model_get_vocab(model);
+        smpl  = common_sampler_init(model, params.sampling);
+        batch = llama_batch_init(1, 0, 1); // batch for next token generation
 
        if (!model || !lctx) {
            exit(1);
@@ -108,20 +207,18 @@ struct mtmd_cli_context {
            LOG_ERR("  For Mistral Small 3.1, use '--chat-template mistral-v7'\n");
            exit(1);
        }
-
-        tmpls = common_chat_templates_init(model, params.chat_template);
-        use_jinja = params.use_jinja;
-        chat_history.clear();
-        LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str());
-
-        init_vision_context(params);
-
+
        // load antiprompt tokens for legacy templates
        if (params.chat_template == "vicuna") {
            antiprompt_tokens = common_tokenize(lctx, "ASSISTANT:", false, true);
        } else if (params.chat_template == "deepseek") {
            antiprompt_tokens = common_tokenize(lctx, "###", false, true);
        }
+
+        if (!ctx_vision.get()) {
+            LOG_ERR("Failed to load vision model from %s\n", params.mmproj.path.c_str());
+            exit(1);
+        }
    }
 
    ~mtmd_cli_context() {
@@ -132,12 +229,13 @@ struct mtmd_cli_context {
    void init_vision_context(common_params & params) {
        const char * clip_path = params.mmproj.path.c_str();
        mtmd_context_params mparams = mtmd_context_params_default();
-        mparams.use_gpu = params.mmproj_use_gpu;
-        mparams.print_timings = true;
-        mparams.n_threads = params.cpuparams.n_threads;
-        mparams.flash_attn_type = params.flash_attn_type;
-        mparams.image_min_tokens = params.image_min_tokens;
-        mparams.image_max_tokens = params.image_max_tokens;
+        mparams.use_gpu           = params.mmproj_use_gpu;
+        mparams.print_timings     = true;
+        mparams.n_threads         = params.cpuparams.n_threads;
+        mparams.flash_attn_type   = params.flash_attn_type;
+        mparams.image_min_tokens  = params.image_min_tokens;
+        mparams.image_max_tokens  = params.image_max_tokens;
+        mparams.clip_reduced_vram = params.clip_reduced_vram;
        ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
        if (!ctx_vision.get()) {
            LOG_ERR("Failed to load vision model from %s\n", clip_path);
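Note: for integrators other than mtmd-cli, the intended sequencing with clip_reduced_vram can be condensed as the following sketch. clip_path, model_cpu, app, my_app_state and create_gpu_context are hypothetical application-side names; only the mtmd calls come from this patch:

    mtmd_context_params mparams = mtmd_context_params_default();
    mparams.clip_reduced_vram = true;                 // CLIP weights stay in host memory
    mtmd_context * ctx_vision = mtmd_init_from_file(clip_path, model_cpu, mparams);

    // runs after the image has been encoded and the temporary CLIP scheduler has released its VRAM
    mtmd_set_llm_init_callback(ctx_vision, [](void * user_data) {
        auto * app = static_cast<my_app_state *>(user_data);   // hypothetical application state
        app->lctx = app->create_gpu_context();                 // e.g. reload the model, then common_init_from_existing_model
        mtmd_set_llm_context(app->ctx_vision, app->lctx);      // hand the fresh context back to the helpers
    }, &app);

    // mtmd_helper_eval_chunks() then pre-encodes the image, invokes the callback, and decodes as usual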
diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp
index f0891bba30..da009ed46a 100644
--- a/tools/mtmd/mtmd-helper.cpp
+++ b/tools/mtmd/mtmd-helper.cpp
@@ -32,6 +32,28 @@
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb/stb_image.h"
 
+
+// internal functions for the JIT llm init integration
+extern "C" {
+    // Retrieve the JIT-initialized llama context if available (NULL if not set)
+    struct llama_context * mtmd_get_llm_context(mtmd_context * ctx);
+
+    // Whether the pre-encode/JIT-llm flow is enabled
+    bool mtmd_preencode_enabled(mtmd_context * ctx);
+
+    // Pre-encode the image chunk; returns 0 on success, or the encode error
+    int32_t mtmd_preencode_image(mtmd_context * ctx, const mtmd_input_chunks * chunks);
+
+    // Invoke the registered JIT LLM init callback if it has not been invoked already
+    void mtmd_invoke_llm_init_if_needed(mtmd_context * ctx);
+
+    // Query the pre-encoded image state
+    bool mtmd_has_preencoded_image(mtmd_context * ctx);
+
+    // Retrieve the encode timing (ms) of the media chunk's underlying encoder; returns 0 if unavailable
+    int64_t mtmd_get_image_encode_timing(mtmd_context * ctx, const mtmd_input_chunk * chunk);
+}
+
 //
 // internal logging functions
 //
@@ -333,7 +355,10 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
        if (logits_last && is_last_token) {
            text_batch.logits[text_batch.n_tokens - 1] = true;
        }
-        ret = llama_decode(lctx, text_batch);
+        {
+            struct llama_context * lctx_eff = mtmd_get_llm_context(ctx) ? mtmd_get_llm_context(ctx) : lctx;
+            ret = llama_decode(lctx_eff, text_batch);
+        }
        if (ret != 0) {
            LOG_ERR("failed to decode text\n");
            llama_batch_free(text_batch);
@@ -347,18 +372,30 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
        int64_t t0 = ggml_time_ms();
 
        LOG_INF("encoding %s slice...\n", name);
-
-        ret = mtmd_encode_chunk(ctx, chunk);
-        if (ret != 0) {
-            LOG_ERR("failed to encode %s slice\n", name);
-            llama_batch_free(text_batch);
-            return ret;
+        // Skip the encode if the image has already been pre-encoded
+        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE && mtmd_has_preencoded_image(ctx)) {
+            LOG_INF("using pre-encoded image embeddings\n");
+        }
+        else {
+            ret = mtmd_encode_chunk(ctx, chunk);
+            if (ret != 0) {
+                LOG_ERR("failed to encode %s slice\n", name);
+                llama_batch_free(text_batch);
+                return ret;
+            }
        }
 
-        LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
+        if (mtmd_preencode_enabled(ctx)) {
+            LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, mtmd_get_image_encode_timing(ctx, chunk));
+        } else {
+            LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
+        }
 
        float * embd = mtmd_get_output_embd(ctx);
-        ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
+        {
+            struct llama_context * lctx_eff = mtmd_get_llm_context(ctx) ? mtmd_get_llm_context(ctx) : lctx;
+            ret = mtmd_helper_decode_image_chunk(ctx, lctx_eff, chunk, embd, n_past, seq_id, n_batch, new_n_past);
+        }
        if (ret != 0) {
            LOG_ERR("failed to decode %s\n", name);
            llama_batch_free(text_batch);
@@ -386,6 +423,19 @@ int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
        return 0;
    }
 
+    // When clip_reduced_vram is enabled, first pre-encode the IMAGE chunk and release CLIP's VRAM before any llama_decode runs.
+    // This preserves the downstream decode order but frees the vision VRAM earlier.
+    if (mtmd_preencode_enabled(ctx)) {
+        LOG_INF("pre-encoding image before any text decode...\n");
+        int32_t ret = mtmd_preencode_image(ctx, chunks);
+        if (ret != 0) {
+            LOG_ERR("failed to pre-encode image\n");
+            return ret;
+        }
+        // Invoke the JIT LLM initialization callback after CLIP has freed its VRAM
+        mtmd_invoke_llm_init_if_needed(ctx);
+    }
+
    for (size_t i = 0; i < n_chunks; i++) {
        bool chunk_logits_last = (i == n_chunks - 1) && logits_last;
        auto chunk = mtmd_input_chunks_get(chunks, i);
diff --git a/tools/mtmd/mtmd-helper.h b/tools/mtmd/mtmd-helper.h
index 5036b92442..25c039dc85 100644
--- a/tools/mtmd/mtmd-helper.h
+++ b/tools/mtmd/mtmd-helper.h
@@ -85,6 +85,12 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
                                                int32_t n_batch,
                                                llama_pos * new_n_past);
 
+// Register a JIT initializer for the LLM context; intended to be invoked after the image encode
+MTMD_API void mtmd_set_llm_init_callback(mtmd_context * ctx, mtmd_llm_init_cb cb, void * user_data);
+
+// Provide the initialized llama context back to mtmd for use in the other helper methods
+MTMD_API void mtmd_set_llm_context(mtmd_context * ctx, struct llama_context * lctx);
+
 #ifdef __cplusplus
 } // extern "C"
 #endif
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 6690bf3004..835e4003f8 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -110,6 +110,7 @@ mtmd_context_params mtmd_context_params_default() {
        /* flash_attn_type   */ LLAMA_FLASH_ATTN_TYPE_AUTO,
        /* image_min_tokens  */ -1,
        /* image_max_tokens  */ -1,
+        /* clip_reduced_vram */ false,
    };
    return params;
 }
@@ -125,6 +126,10 @@ struct mtmd_context {
    std::string media_marker;
    const int n_embd_text;
 
+    bool clip_reduced_vram       = false;
+    bool has_pre_encoded_image   = false;
+    bool llm_context_initialized = false;
+
    // these are not token, but strings used to mark the beginning and end of image/audio embeddings
    std::string img_beg;
    std::string img_end;
@@ -153,6 +158,11 @@ struct mtmd_context {
    // for whisper, we pre-calculate the mel filter bank
    whisper_preprocessor::whisper_filters w_filters;
 
+    // JIT llm init integration when clip_reduced_vram is enabled
+    mtmd_llm_init_cb llm_init_cb = nullptr;
+    void * llm_init_user_data    = nullptr;
+    llama_context * llm_lctx     = nullptr; // set via mtmd_set_llm_context when clip_reduced_vram is enabled
+
    // TODO @ngxson : add timings
 
    mtmd_context(const char * mmproj_fname,
@@ -177,6 +187,7 @@ struct mtmd_context {
            /* flash_attn_type   */ CLIP_FLASH_ATTN_TYPE_AUTO,
            /* image_min_tokens  */ ctx_params.image_min_tokens,
            /* image_max_tokens  */ ctx_params.image_max_tokens,
+            /* clip_reduced_vram */ ctx_params.clip_reduced_vram,
        };
 
        auto res = clip_init(mmproj_fname, ctx_clip_params);
@@ -186,6 +197,9 @@ struct mtmd_context {
            throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
        }
 
+        // store for helper flow control
+        clip_reduced_vram = ctx_params.clip_reduced_vram;
+
        // if both vision and audio mmproj are present, we need to validate their n_embd
        if (ctx_v && ctx_a) {
            int n_embd_v = clip_n_mmproj_embd(ctx_v);
@@ -400,6 +414,60 @@ void mtmd_free(mtmd_context * ctx) {
    delete ctx;
 }
 
+extern "C" MTMD_API void mtmd_set_llm_init_callback(mtmd_context * ctx, mtmd_llm_init_cb cb, void * user_data) {
+    ctx->llm_init_cb        = cb;
+    ctx->llm_init_user_data = user_data;
+}
+
+extern "C" MTMD_API void mtmd_set_llm_context(mtmd_context * ctx, struct llama_context * lctx) {
+    ctx->llm_lctx = lctx;
+}
+
+extern "C" struct llama_context * mtmd_get_llm_context(mtmd_context * ctx) {
+    return ctx->llm_lctx;
+}
+
+extern "C" bool mtmd_preencode_enabled(mtmd_context * ctx) {
+    return ctx->clip_reduced_vram;
+}
+
+extern "C" int32_t mtmd_preencode_image(mtmd_context * ctx, const mtmd_input_chunks * chunks) {
+    const size_t n = chunks ? chunks->entries.size() : 0;
+    for (size_t i = 0; i < n; ++i) {
+        const mtmd_input_chunk & c = chunks->entries[i];
+        if (c.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            int32_t ret = mtmd_encode_chunk(ctx, &c);
+            if (ret != 0) {
+                return ret;
+            }
+            ctx->has_pre_encoded_image = true;
+            return 0;
+        }
+    }
+    return 0; // no image, nothing to do
+}
+
+extern "C" void mtmd_invoke_llm_init_if_needed(mtmd_context * ctx) {
+    if (!ctx->llm_context_initialized && ctx->llm_init_cb) {
+        ctx->llm_init_cb(ctx->llm_init_user_data);
+        ctx->llm_context_initialized = true;
+    }
+}
+
+extern "C" bool mtmd_has_preencoded_image(mtmd_context * ctx) {
+    return ctx->has_pre_encoded_image;
+}
+
+extern "C" int64_t mtmd_get_image_encode_timing(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
+    if (!ctx || !chunk) {
+        return 0;
+    }
+    clip_ctx * c = ctx->get_clip_ctx(chunk);
+    if (!c) {
+        return 0;
+    }
+    return clip_get_image_encode_timing(c);
+}
+
 struct mtmd_tokenizer {
    mtmd_context * ctx;
    std::vector<const mtmd_bitmap *> bitmaps;
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
index 015119be89..7691f32de9 100644
--- a/tools/mtmd/mtmd.h
+++ b/tools/mtmd/mtmd.h
@@ -58,6 +58,8 @@ struct mtmd_image_tokens;
 struct mtmd_input_chunk;
 struct mtmd_input_chunks;
 
+typedef void (*mtmd_llm_init_cb)(void * user_data);
+
 struct mtmd_input_text {
    const char * text;
    bool add_special;
@@ -86,6 +88,8 @@ struct mtmd_context_params {
    // limit number of image tokens, only for vision models with dynamic resolution
    int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
    int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)
+
+    bool clip_reduced_vram; // offload clip weights to CPU and stream them to the backend device at runtime
 };
 
 MTMD_API const char * mtmd_default_marker(void);
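Note: with the pieces above in place, the feature can be exercised from the CLI roughly as follows. The model file names are placeholders; -m, --mmproj, --image, -p and -ngl are existing llama.cpp options, while --clip-reduced-vram and LLAMA_ARG_CLIP_REDUCED_VRAM are added by this patch:

    # keep CLIP weights in host RAM, stream them to the GPU for the single image encode,
    # then reload the LLM with full GPU offload (-ngl) for decoding
    llama-mtmd-cli -m model.gguf --mmproj mmproj.gguf \
        --image input.png -p "Describe this image." \
        -ngl 99 --clip-reduced-vram

    # the same switch is also available via the environment variable
    LLAMA_ARG_CLIP_REDUCED_VRAM=1 llama-mtmd-cli ...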