From 462a51945bb535dcdbab505731c8462024c5efba Mon Sep 17 00:00:00 2001 From: deepshnv Date: Fri, 5 Dec 2025 23:25:58 +0530 Subject: [PATCH 1/2] Efficient VLM inference using llama-mtmd-cli for high resolution images while having lower GPU VRAM requirements. Implemented 3 optis to enable this: i) offload vision model weights(only) to CPU and stream to device at runtime ii) reordering LLM model init so that the CLIP model is done with encoding the image and has freed-up the VRAM memory iii) tiled flash attention to avoid 2GB/INT_MAX limit ggml_cuda_cpy for larger images --- common/arg.cpp | 6 + common/common.cpp | 175 ++++++++++++++++++++++++++++ common/common.h | 4 + tools/mtmd/clip-impl.h | 1 + tools/mtmd/clip.cpp | 228 +++++++++++++++++++++++++++++++++++-- tools/mtmd/clip.h | 1 + tools/mtmd/mtmd-cli.cpp | 142 +++++++++++++++++++---- tools/mtmd/mtmd-helper.cpp | 46 ++++++-- tools/mtmd/mtmd.cpp | 68 +++++++++++ tools/mtmd/mtmd.h | 28 +++++ 10 files changed, 656 insertions(+), 43 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 9f3c8a97546..3b40edfde2f 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1803,6 +1803,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.mmproj_use_gpu = false; } ).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ_OFFLOAD")); + add_opt(common_arg({ "--clip-reduced-vram" }, + "offload clip weights to CPU and stream at runtime to device (default: false)", + [](common_params & params) { + params.clip_reduced_vram = true; + } + ).set_env("LLAMA_ARG_CLIP_REDUCED_VRAM")); add_opt(common_arg( {"--image", "--audio"}, "FILE", "path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n", diff --git a/common/common.cpp b/common/common.cpp index 0d7fd9a9371..8b8a90367e5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1184,6 +1184,181 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } +// Initialize llama_context and related state using an already-loaded llama_model +// - Does not take ownership of the model passed in; caller manages its lifetime +// - Mirrors common_init_from_params except for model creation/free +struct common_init_result common_init_from_existing_model(common_params & params, struct llama_model * model_in) { + common_init_result iparams; + + if (model_in == NULL) { + LOG_ERR("%s: model '%s' is NULL\n", __func__, params.model.path.c_str()); + return iparams; + } + + common_init_sampler_from_model(model_in, params.sampling); + + const llama_vocab * vocab = llama_model_get_vocab(model_in); + + auto cparams = common_context_params_to_llama(params); + + llama_context * lctx = llama_init_from_model(model_in, cparams); + if (lctx == NULL) { + LOG_ERR("%s: failed to create context with existing model '%s'\n", __func__, params.model.path.c_str()); + return iparams; + } + + if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) { + LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); + params.ctx_shift = false; + } + + if (!params.control_vectors.empty()) { + if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; + if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model_in); + + const auto cvec = common_control_vector_load(params.control_vectors); + if (cvec.n_embd == -1) { + llama_free(lctx); + + return iparams; + } + + int err = llama_apply_adapter_cvec( + lctx, + cvec.data.data(), + cvec.data.size(), + cvec.n_embd, + params.control_vector_layer_start, + params.control_vector_layer_end); + if (err) { + llama_free(lctx); + + return iparams; + } + } + + if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) { + bool ok = true; + + if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__); + ok = false; + } + + bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; + bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL; + bool has_rerank_prompt = llama_model_chat_template(model_in, "rerank") != NULL; + + if (!has_eos && !has_sep && !has_rerank_prompt) { + LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__); + ok = false; + } else if (!has_eos) { + LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__); + } + + if (!ok) { + llama_free(lctx); + + return iparams; + } + } + + // load and optionally apply lora adapters + for (auto & la : params.lora_adapters) { + llama_adapter_lora_ptr lora; + lora.reset(llama_adapter_lora_init(model_in, la.path.c_str())); + if (lora == nullptr) { + LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); + llama_free(lctx); + return iparams; + } + + char buf[1024]; + la.ptr = lora.get(); + llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf)); + la.task_name = buf; + llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); + la.prompt_prefix = buf; + iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters + } + + if (!params.lora_init_without_apply) { + common_set_adapter_lora(lctx, params.lora_adapters); + } + + if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); + params.sampling.ignore_eos = false; + } + + // initialize once + for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { + LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); + params.sampling.logit_bias_eog.push_back({i, -INFINITY}); + } + } + + if (params.sampling.ignore_eos) { + // add EOG biases to the active set of logit biases + params.sampling.logit_bias.insert( + params.sampling.logit_bias.end(), + params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end()); + } + + if (params.sampling.penalty_last_n == -1) { + LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + params.sampling.penalty_last_n = llama_n_ctx(lctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); + } + + if (params.warmup) { + LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); + + llama_set_warmup(lctx, true); + + std::vector tmp; + llama_token bos = llama_vocab_bos(vocab); + llama_token eos = llama_vocab_eos(vocab); + + // some models (e.g. T5) don't have a BOS token + if (bos != LLAMA_TOKEN_NULL) { + tmp.push_back(bos); + } + if (eos != LLAMA_TOKEN_NULL) { + tmp.push_back(eos); + } + if (tmp.empty()) { + tmp.push_back(0); + } + + if (llama_model_has_encoder(model_in)) { + llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); + llama_token decoder_start_token_id = llama_model_decoder_start_token(model_in); + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { + decoder_start_token_id = bos; + } + tmp.clear(); + tmp.push_back(decoder_start_token_id); + } + if (llama_model_has_decoder(model_in)) { + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); + } + llama_memory_clear(llama_get_memory(lctx), true); + llama_synchronize(lctx); + llama_perf_context_reset(lctx); + llama_set_warmup(lctx, false); + } + + iparams.context.reset(lctx); + + return iparams; +} + std::string get_model_endpoint() { const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. diff --git a/common/common.h b/common/common.h index 2f23d0baa83..460492b1f56 100644 --- a/common/common.h +++ b/common/common.h @@ -537,6 +537,8 @@ struct common_params { bool has_speculative() const { return !speculative.model.path.empty() || !speculative.model.hf_repo.empty(); } + + bool clip_reduced_vram = false; // keep weights' of vision models offloaded to CPU and stream at runtime to required device }; // call once at the start of a program if it uses libcommon @@ -657,6 +659,8 @@ struct common_init_result { }; struct common_init_result common_init_from_params(common_params & params); +// Initialize only the context using an already-loaded llama_model* +struct common_init_result common_init_from_existing_model(common_params & params, struct llama_model * model); struct llama_model_params common_model_params_to_llama ( common_params & params); struct llama_context_params common_context_params_to_llama(const common_params & params); diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index cd47865bf4a..eaa8f322cbf 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -490,3 +490,4 @@ static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) { // projector_type clip_get_projector_type(const struct clip_ctx * ctx); +int64_t clip_get_image_encode_timing(const struct clip_ctx * ctx); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index d8222d88148..421404da1ca 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -424,11 +424,15 @@ struct clip_ctx { ggml_backend_t backend = nullptr; ggml_backend_t backend_cpu = nullptr; ggml_backend_buffer_ptr buf; + std::vector additional_buffers; // for tensor overrides int max_nodes = 8192; ggml_backend_sched_ptr sched; clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO; + bool clip_reduced_vram = false; + int64_t image_encode_timing = 0; + // for debugging bool debug_graph = false; std::vector debug_print_tensors; @@ -476,6 +480,8 @@ struct clip_ctx { sched.reset( ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false, true) ); + + clip_reduced_vram = ctx_params.clip_reduced_vram; } ~clip_ctx() { @@ -762,6 +768,24 @@ struct clip_graph { inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); } + // Check if tiled flash attention is needed to avoid 2GB/INT_MAX ggml_cpy limit in cpy.cu: GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); + bool needs_tiled_fa = false; + size_t tiled_fa_q_tile = n_pos; + if (use_window_attn && ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { + // Calculate max q_tile that keeps mask under 2GB during cast/pad operations + // am taking worst case mask size = n_kv × qlen × element_bytes (worst case F32 = 4 bytes) + const int64_t n_kv = n_pos; + const size_t mask_element_bytes = sizeof(float); + const int64_t max_qlen = (INT_MAX / (n_kv * (int64_t)mask_element_bytes)) & ~63LL; + if ((int64_t)tiled_fa_q_tile > max_qlen && max_qlen >= 64) { + tiled_fa_q_tile = (size_t)max_qlen; + needs_tiled_fa = true; + LOG_INF("%s: will use tiled FA with q_tile=%zu to keep mask under 2GB/INT_MAX (n_pos=%d)\n", __func__, tiled_fa_q_tile, n_pos); + } + } + + ggml_tensor * window_mask_raw = nullptr; // unprepared mask for tiled FA path + if (use_window_attn) { // handle window attention inputs inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4); @@ -772,13 +796,19 @@ struct clip_graph { ggml_set_name(window_mask, "window_mask"); ggml_set_input(window_mask); - // if flash attn is used, we need to pad the mask and cast to f16 + // if tiling is needed, keep the raw mask and handle padding/casting per-tile, else pad and cast the full mask if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { - int n_pad = GGML_PAD(window_mask->ne[1], GGML_KQ_MASK_PAD) - window_mask->ne[1]; - if (n_pad > 0) { - window_mask = ggml_pad(ctx0, window_mask, 0, n_pad, 0, 0); + if (needs_tiled_fa) { + // Keep raw mask for per-tile processing + window_mask_raw = window_mask; + } + else { + int n_pad = GGML_PAD(window_mask->ne[1], GGML_KQ_MASK_PAD) - window_mask->ne[1]; + if (n_pad > 0) { + window_mask = ggml_pad(ctx0, window_mask, 0, n_pad, 0, 0); + } + window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16); } - window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16); } // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size] @@ -829,8 +859,17 @@ struct clip_graph { ggml_tensor * attn_mask = full_attn ? nullptr : window_mask; - cur = build_attn(layer.o_w, layer.o_b, - Qcur, Kcur, Vcur, attn_mask, kq_scale, il); + // Use tiled flash attention when needed (calculated before layer loop) + const bool use_tiled_fa = needs_tiled_fa && !full_attn; + + if (use_tiled_fa) { + cur = build_flash_attn_tiled(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, window_mask_raw, kq_scale, tiled_fa_q_tile, il); + } + else { + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, attn_mask, kq_scale, il); + } cb(cur, "attn_out", il); } @@ -2370,6 +2409,82 @@ struct clip_graph { return cur; } + // Tiled flash attention to avoid 2GB/INT_MAX ggml_cuda_cpy: GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); + // Tiles the Q tensor into smaller chunks, processes each with flash attention, then concatenates results and applies output projection before returning cur + ggml_tensor * build_flash_attn_tiled(ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, + ggml_tensor * kq_mask_raw, float kq_scale, size_t q_tile_size, int il) const + { + const int64_t d_head = q_cur->ne[0]; + const int64_t n_head = q_cur->ne[1]; + const int64_t num_positions = q_cur->ne[2]; + const int64_t batch_size = 1; + + ggml_tensor * k_fa = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); + ggml_tensor * v_fa = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); + if (k_fa->type == GGML_TYPE_F32) { + k_fa = ggml_cast(ctx0, k_fa, GGML_TYPE_F16); + } + if (v_fa->type == GGML_TYPE_F32) { + v_fa = ggml_cast(ctx0, v_fa, GGML_TYPE_F16); + } + + std::vector tile_results; + for (int64_t q0 = 0; q0 < num_positions; q0 += (int64_t)q_tile_size) { + const int64_t qlen = std::min((int64_t)q_tile_size, num_positions - q0); + + // Create Q tile view matching original layout, then permute for FA + // q_cur shape: [d_head, n_head, n_positions] -> view as [d_head, n_head, qlen, 1] + const size_t q_off = (size_t)q0 * q_cur->nb[2]; + ggml_tensor * q_tile = ggml_view_4d(ctx0, q_cur, + d_head, n_head, qlen, batch_size, + q_cur->nb[1], q_cur->nb[2], 0, + q_cur->view_offs + q_off); + // Permute to FA layout: [d_head, n_head, qlen, B] -> [d_head, qlen, n_head, B] + q_tile = ggml_permute(ctx0, q_tile, 0, 2, 1, 3); + + ggml_tensor * mask_tile = nullptr; + if (kq_mask_raw != nullptr) { + const size_t m_off = (size_t)q0 * kq_mask_raw->nb[1]; + mask_tile = ggml_view_2d(ctx0, kq_mask_raw, + kq_mask_raw->ne[0], qlen, + kq_mask_raw->nb[1], + kq_mask_raw->view_offs + m_off); + + // Pad and cast mask tile for flash attention + const int64_t padded_qlen = GGML_PAD(qlen, GGML_KQ_MASK_PAD); + if (qlen < padded_qlen) { + mask_tile = ggml_pad(ctx0, mask_tile, 0, (int)(padded_qlen - qlen), 0, 0); + } + mask_tile = ggml_cast(ctx0, mask_tile, GGML_TYPE_F16); + } + + ggml_tensor * tile_out = ggml_flash_attn_ext(ctx0, q_tile, k_fa, v_fa, mask_tile, kq_scale, 0.0f, 0.0f); + ggml_flash_attn_ext_set_prec(tile_out, GGML_PREC_F32); + + // Reshape to 2D for concatenation: [d_head * n_head, qlen] + tile_out = ggml_reshape_2d(ctx0, tile_out, tile_out->ne[0] * tile_out->ne[1], tile_out->ne[2]); + tile_results.push_back(tile_out); + } + + // Concatenate all tiles along sequence dimension + ggml_tensor * cur = tile_results[0]; + for (size_t i = 1; i < tile_results.size(); i++) { + cur = ggml_concat(ctx0, cur, tile_results[i], 1); + } + cur = ggml_cont(ctx0, cur); + + cb(cur, "kqv_out_tiled", il); + + // Apply output projection + if (wo) { + cur = ggml_mul_mat(ctx0, wo, cur); + } + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + return cur; + } + // implementation of the 2D RoPE without adding a new op in ggml // this is not efficient (use double the memory), but works on all backends // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065 @@ -3263,8 +3378,67 @@ struct clip_model_loader { // alloc memory and offload data ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend); - ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft)); - ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + ggml_backend_buffer_type_t cpu_buft = ggml_backend_get_default_buffer_type(ctx_clip.backend_cpu); + + if (ctx_clip.clip_reduced_vram) { + LOG_INF("%s: clip_reduced_vram enabled, offloading all clip tensors to CPU\n", __func__); + std::map> tensors_by_buft; + // Determine buffer type for each tensor + for (auto & t : tensors_to_load) { + struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); + tensors_by_buft[cpu_buft].push_back(cur); + } + + // For each buffer type, create a context and allocate tensors + bool first_buffer = true; + for (auto & [buft, tensor_list] : tensors_by_buft) + { + // Create a temporary context for this buffer type + struct ggml_init_params temp_params = { + /*.mem_size =*/tensor_list.size() * ggml_tensor_overhead(), + /*.mem_buffer =*/NULL, + /*.no_alloc =*/true, + }; + ggml_context_ptr temp_ctx(ggml_init(temp_params)); + if (!temp_ctx) { + throw std::runtime_error(string_format("%s: failed to create temporary context\n", __func__)); + } + + // Create tensor references in the temporary context + for (auto * orig_tensor : tensor_list) { + ggml_tensor * temp_tensor = ggml_dup_tensor(temp_ctx.get(), orig_tensor); + ggml_set_name(temp_tensor, orig_tensor->name); + } + + // Allocate buffer for this context + auto buffer = ggml_backend_buffer_ptr(ggml_backend_alloc_ctx_tensors_from_buft(temp_ctx.get(), buft)); + if (!buffer) { + throw std::runtime_error(string_format("%s: failed to allocate buffer for %s\n", __func__, ggml_backend_buft_name(buft))); + } + ggml_backend_buffer_set_usage(buffer.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + + // Copy tensor pointers to original tensors + for (auto * orig_tensor : tensor_list) { + ggml_tensor * temp_tensor = ggml_get_tensor(temp_ctx.get(), orig_tensor->name); + orig_tensor->buffer = temp_tensor->buffer; + orig_tensor->data = temp_tensor->data; + } + + // transfer ownership of the buffer to ctx_clip so it lives as long as the model + if (first_buffer) { + ctx_clip.buf = std::move(buffer); + first_buffer = false; + } else { + ctx_clip.additional_buffers.push_back(std::move(buffer)); + } + } + } + else { + // Baseline: no overrides are needed + ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft)); + ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } + for (auto & t : tensors_to_load) { ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); const size_t offset = tensor_offset[t->name]; @@ -4617,6 +4791,7 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3 } bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) { + int64_t t0 = ggml_time_ms(); const clip_image_f32_batch & imgs = *imgs_c_ptr; int batch_size = imgs.entries.size(); @@ -4628,9 +4803,23 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // build the inference graph ctx->debug_print_tensors.clear(); - ggml_backend_sched_reset(ctx->sched.get()); ggml_cgraph * gf = clip_image_build_graph(ctx, imgs); - ggml_backend_sched_alloc_graph(ctx->sched.get(), gf); + + // if clip_reduced_vram enabled: use a temporary scheduler to free VRAM right after encode + ggml_backend_sched_t sched_to_use = nullptr; + ggml_backend_sched_ptr sched_local; + if (ctx->clip_reduced_vram) { + sched_local.reset(ggml_backend_sched_new(ctx->backend_ptrs.data(), ctx->backend_buft.data(), (int) ctx->backend_ptrs.size(), ctx->max_nodes, + /*parallel*/ false, /*op_offload*/ true)); + ggml_backend_sched_reset(sched_local.get()); + sched_to_use = sched_local.get(); + } + else { + ggml_backend_sched_reset(ctx->sched.get()); + sched_to_use = ctx->sched.get(); + } + + ggml_backend_sched_alloc_graph(sched_to_use, gf); // set inputs const auto & model = ctx->model; @@ -4965,7 +5154,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } - auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf); + auto status = ggml_backend_sched_graph_compute(sched_to_use, gf); if (status != GGML_STATUS_SUCCESS) { LOG_ERR("%s: ggml_backend_sched_graph_compute failed with error %d\n", __func__, status); return false; @@ -4997,6 +5186,17 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // copy the embeddings to the location passed by the user ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); + // if a temporary scheduler was used, freeing it here releases its compute buffers/VRAM + if (ctx->clip_reduced_vram) { + // ensure all device work is finished before destroying the temporary scheduler + ggml_backend_sched_synchronize(sched_to_use); + // explicit scope reset to free underlying resources now + sched_local.reset(); + // synchronize CPU backend for completeness, is it required? + ggml_backend_synchronize(ctx->backend_cpu); + } + ctx->image_encode_timing = ggml_time_ms() - t0; + return true; } @@ -5116,3 +5316,7 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel batch->entries.push_back(clip_image_f32_ptr(audio)); batch->is_audio = true; } + +int64_t clip_get_image_encode_timing(const struct clip_ctx * ctx) { + return ctx->image_encode_timing; +} diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index c1442afe6b2..65307b63cdd 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -34,6 +34,7 @@ struct clip_context_params { enum clip_flash_attn_type flash_attn_type; int image_min_tokens; int image_max_tokens; + bool clip_reduced_vram; // offload clip weights to CPU and stream at runtime to backend device }; struct clip_init_result { diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 6679de309b4..c8bd47ed073 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -87,15 +87,114 @@ struct mtmd_cli_context { int n_threads = 1; llama_pos n_past = 0; + common_params saved_params; // keep a copy for JIT LLM init + int32_t saved_n_gpu_layers = 0; + + mtmd_cli_context(common_params & params) { + if (params.clip_reduced_vram) + { + saved_params = params; + saved_n_gpu_layers = params.n_gpu_layers; // save original GPU layers + n_threads = params.cpuparams.n_threads; + n_batch = params.n_batch; + + // Load LLM model with n_gpu_layers=0 (CPU only) initially, will move to n_gpu_layers JIT later + params.n_gpu_layers = 0; + LOG_INF("%s: clip_reduced_vram enabled - loading LLM model with n_gpu_layers=0\n", __func__); + + // Defer LLM context init; still need model to build chat templates + auto mparams = common_model_params_to_llama(params); + model = llama_model_load_from_file(params.model.path.c_str(), mparams); + if (!model) { + exit(1); + } + tmpls = common_chat_templates_init(model, params.chat_template); + use_jinja = params.use_jinja; + chat_history.clear(); + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str()); + + init_vision_context(params); + mtmd_set_llm_init_callback( + ctx_vision.get(), + [](void * user_data) { + auto * self = static_cast(user_data); + self->init_llm_context(self->saved_params); + // pass the fresh lctx back into mtmd for helper functions + mtmd_set_llm_context(self->ctx_vision.get(), self->lctx); + }, + this); + } + else + { + // Baseline + llama_init = common_init_from_params(params); + model = llama_init.model.get(); + lctx = llama_init.context.get(); + vocab = llama_model_get_vocab(model); + smpl = common_sampler_init(model, params.sampling); + n_threads = params.cpuparams.n_threads; + batch = llama_batch_init(1, 0, 1); // batch for next token generation + n_batch = params.n_batch; + + if (!model || !lctx) { + exit(1); + } + + if (!llama_model_chat_template(model, nullptr) && params.chat_template.empty()) { + LOG_ERR("Model does not have chat template.\n"); + LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n"); + LOG_ERR(" For MobileVLM models, use '--chat-template deepseek'\n"); + LOG_ERR(" For Mistral Small 3.1, use '--chat-template mistral-v7'\n"); + exit(1); + } + + tmpls = common_chat_templates_init(model, params.chat_template); + use_jinja = params.use_jinja; + chat_history.clear(); + LOG_INF("%s: chat template example:\n%s\n", __func__, + common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str()); + + init_vision_context(params); + + // load antiprompt tokens for legacy templates + if (params.chat_template == "vicuna") { + antiprompt_tokens = common_tokenize(lctx, "ASSISTANT:", false, true); + } else if (params.chat_template == "deepseek") { + antiprompt_tokens = common_tokenize(lctx, "###", false, true); + } + } + } + + void init_llm_context(common_params& params) + { + // Free the CPU-loaded model + if (model) { + llama_model_free(model); + model = nullptr; + } + + // Reload model with original n_gpu_layers (GPU offloading) + params.n_gpu_layers = saved_n_gpu_layers; + LOG_INF("%s: reloading LLM model JIT with n_gpu_layers=%d\n", __func__, saved_n_gpu_layers); - mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) { - model = llama_init.model.get(); - lctx = llama_init.context.get(); - vocab = llama_model_get_vocab(model); - smpl = common_sampler_init(model, params.sampling); - n_threads = params.cpuparams.n_threads; - batch = llama_batch_init(1, 0, 1); // batch for next token generation - n_batch = params.n_batch; + auto mparams = common_model_params_to_llama(params); + model = llama_model_load_from_file(params.model.path.c_str(), mparams); + if (!model) { + LOG_ERR("Failed to reload LLM model JIT\n"); + exit(1); + } + + // Create context from the GPU-loaded model + auto init2 = common_init_from_existing_model(params, model); + llama_init.model.reset(model); + llama_init.context = std::move(init2.context); + + // refresh raw pointers + model = llama_init.model.get(); + lctx = llama_init.context.get(); + vocab = llama_model_get_vocab(model); + smpl = common_sampler_init(model, params.sampling); + batch = llama_batch_init(1, 0, 1); // batch for next token generation if (!model || !lctx) { exit(1); @@ -108,20 +207,18 @@ struct mtmd_cli_context { LOG_ERR(" For Mistral Small 3.1, use '--chat-template mistral-v7'\n"); exit(1); } - - tmpls = common_chat_templates_init(model, params.chat_template); - use_jinja = params.use_jinja; - chat_history.clear(); - LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str()); - - init_vision_context(params); - + // load antiprompt tokens for legacy templates if (params.chat_template == "vicuna") { antiprompt_tokens = common_tokenize(lctx, "ASSISTANT:", false, true); } else if (params.chat_template == "deepseek") { antiprompt_tokens = common_tokenize(lctx, "###", false, true); } + + if (!ctx_vision.get()) { + LOG_ERR("Failed to load vision model from %s\n", params.mmproj.path); + exit(1); + } } ~mtmd_cli_context() { @@ -132,12 +229,13 @@ struct mtmd_cli_context { void init_vision_context(common_params & params) { const char * clip_path = params.mmproj.path.c_str(); mtmd_context_params mparams = mtmd_context_params_default(); - mparams.use_gpu = params.mmproj_use_gpu; - mparams.print_timings = true; - mparams.n_threads = params.cpuparams.n_threads; - mparams.flash_attn_type = params.flash_attn_type; - mparams.image_min_tokens = params.image_min_tokens; - mparams.image_max_tokens = params.image_max_tokens; + mparams.use_gpu = params.mmproj_use_gpu; + mparams.print_timings = true; + mparams.n_threads = params.cpuparams.n_threads; + mparams.flash_attn_type = params.flash_attn_type; + mparams.image_min_tokens = params.image_min_tokens; + mparams.image_max_tokens = params.image_max_tokens; + mparams.clip_reduced_vram = params.clip_reduced_vram; ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams)); if (!ctx_vision.get()) { LOG_ERR("Failed to load vision model from %s\n", clip_path); diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index f0891bba30d..07469b24938 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -333,7 +333,10 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, if (logits_last && is_last_token) { text_batch.logits[text_batch.n_tokens - 1] = true; } - ret = llama_decode(lctx, text_batch); + { + struct llama_context * lctx_eff = mtmd_get_llm_context(ctx) ? mtmd_get_llm_context(ctx) : lctx; + ret = llama_decode(lctx_eff, text_batch); + } if (ret != 0) { LOG_ERR("failed to decode text\n"); llama_batch_free(text_batch); @@ -347,18 +350,30 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, int64_t t0 = ggml_time_ms(); LOG_INF("encoding %s slice...\n", name); - - ret = mtmd_encode_chunk(ctx, chunk); - if (ret != 0) { - LOG_ERR("failed to encode %s slice\n", name); - llama_batch_free(text_batch); - return ret; + // Skip encode if we have pre-encoded the same image (identified by id) + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE && mtmd_has_preencoded_image(ctx)) { + LOG_INF("using pre-encoded image embeddings\n"); + } + else { + ret = mtmd_encode_chunk(ctx, chunk); + if (ret != 0) { + LOG_ERR("failed to encode %s slice\n", name); + llama_batch_free(text_batch); + return ret; + } } - LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0); + if (mtmd_preencode_enabled(ctx)) { + LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, mtmd_get_image_encode_timing(ctx, chunk)); + } else { + LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0); + } float * embd = mtmd_get_output_embd(ctx); - ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past); + { + struct llama_context * lctx_eff = mtmd_get_llm_context(ctx) ? mtmd_get_llm_context(ctx) : lctx; + ret = mtmd_helper_decode_image_chunk(ctx, lctx_eff, chunk, embd, n_past, seq_id, n_batch, new_n_past); + } if (ret != 0) { LOG_ERR("failed to decode %s\n", name); llama_batch_free(text_batch); @@ -386,6 +401,19 @@ int32_t mtmd_helper_eval_chunks(mtmd_context * ctx, return 0; } + // When clip_reduced_vram is enabled, first pre-encode the IMAGE chunk and then release CLIP VRAM before any llama_decode runs. + // This preserves downstream decode order but frees vision VRAM earlier. + if (mtmd_preencode_enabled(ctx)) { + LOG_INF("pre-encoding image before any text decode...\n"); + int32_t ret = mtmd_preencode_image(ctx, chunks); + if (ret != 0) { + LOG_ERR("failed to pre-encode image\n"); + return ret; + } + // Invoke JIT LLM initialization callback after CLIP has freed VRAM + mtmd_invoke_llm_init_if_needed(ctx); + } + for (size_t i = 0; i < n_chunks; i++) { bool chunk_logits_last = (i == n_chunks - 1) && logits_last; auto chunk = mtmd_input_chunks_get(chunks, i); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 6690bf30046..f04a622b6ef 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -110,6 +110,7 @@ mtmd_context_params mtmd_context_params_default() { /* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO, /* image_min_tokens */ -1, /* image_max_tokens */ -1, + /* clip_reduced_vram */ false, }; return params; } @@ -125,6 +126,10 @@ struct mtmd_context { std::string media_marker; const int n_embd_text; + bool clip_reduced_vram = false; + bool has_pre_encoded_image = false; + bool llm_context_initialized = false; + // these are not token, but strings used to mark the beginning and end of image/audio embeddings std::string img_beg; std::string img_end; @@ -153,6 +158,11 @@ struct mtmd_context { // for whisper, we pre-calculate the mel filter bank whisper_preprocessor::whisper_filters w_filters; + // JIT llm init integration when clip_reduced_vram is enabled + mtmd_llm_init_cb llm_init_cb = nullptr; + void * llm_init_user_data = nullptr; + llama_context * llm_lctx = nullptr; // set via mtmd_set_llm_context when clip_reduced_vram is enabled + // TODO @ngxson : add timings mtmd_context(const char * mmproj_fname, @@ -177,6 +187,7 @@ struct mtmd_context { /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO, /* image_min_tokens */ ctx_params.image_min_tokens, /* image_max_tokens */ ctx_params.image_max_tokens, + /* clip_reduced_vram */ ctx_params.clip_reduced_vram, }; auto res = clip_init(mmproj_fname, ctx_clip_params); @@ -186,6 +197,9 @@ struct mtmd_context { throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname)); } + // store for helper flow control + clip_reduced_vram = ctx_params.clip_reduced_vram; + // if both vision and audio mmproj are present, we need to validate their n_embd if (ctx_v && ctx_a) { int n_embd_v = clip_n_mmproj_embd(ctx_v); @@ -400,6 +414,60 @@ void mtmd_free(mtmd_context * ctx) { delete ctx; } +extern "C" MTMD_API void mtmd_set_llm_init_callback(mtmd_context * ctx, mtmd_llm_init_cb cb, void * user_data) { + ctx->llm_init_cb = cb; + ctx->llm_init_user_data = user_data; +} + +extern "C" MTMD_API void mtmd_set_llm_context(mtmd_context * ctx, struct llama_context * lctx) { + ctx->llm_lctx = lctx; +} + +extern "C" MTMD_API struct llama_context * mtmd_get_llm_context(mtmd_context * ctx) { + return ctx->llm_lctx; +} + +extern "C" MTMD_API bool mtmd_preencode_enabled(mtmd_context * ctx) { + return ctx->clip_reduced_vram; +} + +extern "C" MTMD_API int32_t mtmd_preencode_image(mtmd_context * ctx, const mtmd_input_chunks * chunks) { + const size_t n = chunks ? chunks->entries.size() : 0; + for (size_t i = 0; i < n; ++i) { + const mtmd_input_chunk & c = chunks->entries[i]; + if (c.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + int32_t ret = mtmd_encode_chunk(ctx, &c); + if (ret != 0) + return ret; + ctx->has_pre_encoded_image = true; + return 0; + } + } + return 0; // no image, nothing to do +} + +extern "C" MTMD_API void mtmd_invoke_llm_init_if_needed(mtmd_context * ctx) { + if (!ctx->llm_context_initialized && ctx->llm_init_cb) { + ctx->llm_init_cb(ctx->llm_init_user_data); + ctx->llm_context_initialized = true; + } +} + +extern "C" MTMD_API bool mtmd_has_preencoded_image(mtmd_context * ctx) { + return ctx->has_pre_encoded_image; +} + +extern "C" MTMD_API int64_t mtmd_get_image_encode_timing(mtmd_context * ctx, const mtmd_input_chunk * chunk) { + if (!ctx || !chunk) { + return 0; + } + clip_ctx * c = ctx->get_clip_ctx(chunk); + if (!c) { + return 0; + } + return clip_get_image_encode_timing(c); +} + struct mtmd_tokenizer { mtmd_context * ctx; std::vector bitmaps; diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 015119be897..0b87cad1b24 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -58,6 +58,8 @@ struct mtmd_image_tokens; struct mtmd_input_chunk; struct mtmd_input_chunks; +typedef void (*mtmd_llm_init_cb)(void * user_data); + struct mtmd_input_text { const char * text; bool add_special; @@ -86,6 +88,8 @@ struct mtmd_context_params { // limit number of image tokens, only for vision models with dynamic resolution int image_min_tokens; // minimum number of tokens for image input (default: read from metadata) int image_max_tokens; // maximum number of tokens for image input (default: read from metadata) + + bool clip_reduced_vram; // offload clip weights to CPU and stream at runtime to backend device }; MTMD_API const char * mtmd_default_marker(void); @@ -100,6 +104,30 @@ MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname, MTMD_API void mtmd_free(mtmd_context * ctx); +// Register a JIT initializer for the LLM context; intended to be invoked after image encode +MTMD_API void mtmd_set_llm_init_callback(mtmd_context * ctx, mtmd_llm_init_cb cb, void * user_data); + +// Provide the initialized llama context back to mtmd for use in other methods +MTMD_API void mtmd_set_llm_context(mtmd_context * ctx, struct llama_context * lctx); + +// Retrieve the JIT-initialized llama context if available (NULL if not set) +MTMD_API struct llama_context * mtmd_get_llm_context(mtmd_context * ctx); + +// Whether pre-encode/JIT-llm flow is enabled +MTMD_API bool mtmd_preencode_enabled(mtmd_context * ctx); + +// Pre-encode the image chunk; returns 0 on success, or encode error +MTMD_API int32_t mtmd_preencode_image(mtmd_context * ctx, const mtmd_input_chunks * chunks); + +// Invoke the registered JIT LLM init callback if not already invoked +MTMD_API void mtmd_invoke_llm_init_if_needed(mtmd_context * ctx); + +// Query pre-encoded image state and identity +MTMD_API bool mtmd_has_preencoded_image(mtmd_context * ctx); + +// Retrieve the encode timing (ms) for the media chunk's underlying encoder, returns 0 if unavailable +MTMD_API int64_t mtmd_get_image_encode_timing(mtmd_context * ctx, const mtmd_input_chunk * chunk); + // whether we need to set non-causal mask before llama_decode MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx); From 42974cc2acca81d1e06a6dc544aebaa86cd5fb38 Mon Sep 17 00:00:00 2001 From: deepshnv Date: Mon, 8 Dec 2025 19:30:36 +0530 Subject: [PATCH 2/2] moved JIT llm init helpers from mtmd.h to mtmd-helper.h. Only mtmd_set_llm_init_callback and mtmd_set_llm_context remain exported as MTMD_API as they are needed in mtmd-cli.cpp, rest other JIT functions are now internal --- tools/mtmd/mtmd-helper.cpp | 22 ++++++++++++++++++++++ tools/mtmd/mtmd-helper.h | 6 ++++++ tools/mtmd/mtmd.cpp | 12 ++++++------ tools/mtmd/mtmd.h | 24 ------------------------ 4 files changed, 34 insertions(+), 30 deletions(-) diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 07469b24938..da009ed46ad 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -32,6 +32,28 @@ #define STB_IMAGE_IMPLEMENTATION #include "stb/stb_image.h" + +// internal functions for JIT llm init integration +extern "C" { + // Retrieve the JIT-initialized llama context if available (NULL if not set) + struct llama_context * mtmd_get_llm_context(mtmd_context * ctx); + + // Whether pre-encode/JIT-llm flow is enabled + bool mtmd_preencode_enabled(mtmd_context * ctx); + + // Pre-encode the image chunk; returns 0 on success, or encode error + int32_t mtmd_preencode_image(mtmd_context * ctx, const mtmd_input_chunks * chunks); + + // Invoke the registered JIT LLM init callback if not already invoked + void mtmd_invoke_llm_init_if_needed(mtmd_context * ctx); + + // Query pre-encoded image state and identity + bool mtmd_has_preencoded_image(mtmd_context * ctx); + + // Retrieve the encode timing (ms) for the media chunk's underlying encoder, returns 0 if unavailable + int64_t mtmd_get_image_encode_timing(mtmd_context * ctx, const mtmd_input_chunk * chunk); +} + // // internal logging functions // diff --git a/tools/mtmd/mtmd-helper.h b/tools/mtmd/mtmd-helper.h index 5036b92442a..25c039dc852 100644 --- a/tools/mtmd/mtmd-helper.h +++ b/tools/mtmd/mtmd-helper.h @@ -85,6 +85,12 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx, int32_t n_batch, llama_pos * new_n_past); +// Register a JIT initializer for the LLM context; intended to be invoked after image encode +MTMD_API void mtmd_set_llm_init_callback(mtmd_context * ctx, mtmd_llm_init_cb cb, void * user_data); + +// Provide the initialized llama context back to mtmd for use in other methods +MTMD_API void mtmd_set_llm_context(mtmd_context * ctx, struct llama_context * lctx); + #ifdef __cplusplus } // extern "C" #endif diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index f04a622b6ef..835e4003f89 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -423,15 +423,15 @@ extern "C" MTMD_API void mtmd_set_llm_context(mtmd_context * ctx, struct llama_c ctx->llm_lctx = lctx; } -extern "C" MTMD_API struct llama_context * mtmd_get_llm_context(mtmd_context * ctx) { +extern "C" struct llama_context * mtmd_get_llm_context(mtmd_context * ctx) { return ctx->llm_lctx; } -extern "C" MTMD_API bool mtmd_preencode_enabled(mtmd_context * ctx) { +extern "C" bool mtmd_preencode_enabled(mtmd_context * ctx) { return ctx->clip_reduced_vram; } -extern "C" MTMD_API int32_t mtmd_preencode_image(mtmd_context * ctx, const mtmd_input_chunks * chunks) { +extern "C" int32_t mtmd_preencode_image(mtmd_context * ctx, const mtmd_input_chunks * chunks) { const size_t n = chunks ? chunks->entries.size() : 0; for (size_t i = 0; i < n; ++i) { const mtmd_input_chunk & c = chunks->entries[i]; @@ -446,18 +446,18 @@ extern "C" MTMD_API int32_t mtmd_preencode_image(mtmd_context * ctx, const mtmd_ return 0; // no image, nothing to do } -extern "C" MTMD_API void mtmd_invoke_llm_init_if_needed(mtmd_context * ctx) { +extern "C" void mtmd_invoke_llm_init_if_needed(mtmd_context * ctx) { if (!ctx->llm_context_initialized && ctx->llm_init_cb) { ctx->llm_init_cb(ctx->llm_init_user_data); ctx->llm_context_initialized = true; } } -extern "C" MTMD_API bool mtmd_has_preencoded_image(mtmd_context * ctx) { +extern "C" bool mtmd_has_preencoded_image(mtmd_context * ctx) { return ctx->has_pre_encoded_image; } -extern "C" MTMD_API int64_t mtmd_get_image_encode_timing(mtmd_context * ctx, const mtmd_input_chunk * chunk) { +extern "C" int64_t mtmd_get_image_encode_timing(mtmd_context * ctx, const mtmd_input_chunk * chunk) { if (!ctx || !chunk) { return 0; } diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 0b87cad1b24..7691f32de9f 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -104,30 +104,6 @@ MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname, MTMD_API void mtmd_free(mtmd_context * ctx); -// Register a JIT initializer for the LLM context; intended to be invoked after image encode -MTMD_API void mtmd_set_llm_init_callback(mtmd_context * ctx, mtmd_llm_init_cb cb, void * user_data); - -// Provide the initialized llama context back to mtmd for use in other methods -MTMD_API void mtmd_set_llm_context(mtmd_context * ctx, struct llama_context * lctx); - -// Retrieve the JIT-initialized llama context if available (NULL if not set) -MTMD_API struct llama_context * mtmd_get_llm_context(mtmd_context * ctx); - -// Whether pre-encode/JIT-llm flow is enabled -MTMD_API bool mtmd_preencode_enabled(mtmd_context * ctx); - -// Pre-encode the image chunk; returns 0 on success, or encode error -MTMD_API int32_t mtmd_preencode_image(mtmd_context * ctx, const mtmd_input_chunks * chunks); - -// Invoke the registered JIT LLM init callback if not already invoked -MTMD_API void mtmd_invoke_llm_init_if_needed(mtmd_context * ctx); - -// Query pre-encoded image state and identity -MTMD_API bool mtmd_has_preencoded_image(mtmd_context * ctx); - -// Retrieve the encode timing (ms) for the media chunk's underlying encoder, returns 0 if unavailable -MTMD_API int64_t mtmd_get_image_encode_timing(mtmd_context * ctx, const mtmd_input_chunk * chunk); - // whether we need to set non-causal mask before llama_decode MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);