From 102787d19b1bc73034de6c675382c990359ef458 Mon Sep 17 00:00:00 2001 From: HimariO Date: Sun, 2 Feb 2025 18:02:07 +0800 Subject: [PATCH 01/12] implment vision model architecture, gguf convertor --- examples/llava/clip-impl.h | 5 + examples/llava/clip.cpp | 135 +++++++++++++++++++---- examples/llava/qwen2_vl_surgery.py | 168 ++++++++++++++++++----------- 3 files changed, 222 insertions(+), 86 deletions(-) diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h index 4d7340a56bd0c..55b18dc893af0 100644 --- a/examples/llava/clip-impl.h +++ b/examples/llava/clip-impl.h @@ -26,6 +26,8 @@ #define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" #define KEY_USE_GELU "clip.use_gelu" #define KEY_USE_SILU "clip.use_silu" +#define KEY_USE_GLU_MLP "clip.use_glu_mlp" +#define KEY_USE_RMS_NORM "clip.use_rms_norm" #define KEY_N_EMBD "clip.%s.embedding_length" #define KEY_N_FF "clip.%s.feed_forward_length" #define KEY_N_BLOCK "clip.%s.block_count" @@ -44,6 +46,8 @@ #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" +#define KEY_FULLATTN_BLK_IDX "clip.vision.fullatt_block_indexes" +#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" // @@ -62,6 +66,7 @@ #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" #define TN_FFN_UP "%s.blk.%d.ffn_up.%s" +#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_LN_1 "%s.blk.%d.ln1.%s" #define TN_LN_2 "%s.blk.%d.ln2.%s" #define TN_LN_PRE "%s.pre_ln.%s" diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 49c90b7506e73..986db312f2e8f 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -166,6 +166,7 @@ struct clip_hparams { std::vector image_grid_pinpoints; int32_t image_crop_resolution; std::unordered_set vision_feature_layer; + std::vector full_attn_layers; }; struct clip_layer { @@ -191,6 +192,9 @@ struct clip_layer { struct ggml_tensor * ff_o_w = nullptr; struct ggml_tensor * ff_o_b = nullptr; + struct ggml_tensor * ff_g_w = NULL; + struct ggml_tensor * ff_g_b = NULL; + // layernorm 2 struct ggml_tensor * ln_2_w = nullptr; struct ggml_tensor * ln_2_b = nullptr; @@ -314,6 +318,9 @@ struct clip_ctx { float image_std[3]; bool use_gelu = false; bool use_silu = false; + bool use_glu_mlp = false; + bool use_rms_norm = false; + int32_t ftype = 1; gguf_context_ptr ctx_gguf; ggml_context_ptr ctx_data; @@ -552,6 +559,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im const int n_head = hparams.n_head; const int d_head = hidden_size / n_head; const float eps = hparams.eps; + const bool use_window_attn = hparams.full_attn_layers.size() > 0; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; const int batch_size = imgs.entries.size(); @@ -604,8 +612,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); inp = ggml_add(ctx0, inp, model.patch_bias); } - struct ggml_tensor * embeddings = inp; - struct ggml_tensor * pos_embed = nullptr; + struct ggml_tensor * embeddings = inp; + struct ggml_tensor * pos_embed = nullptr; + struct ggml_tensor * window_mask = nullptr; + struct ggml_tensor * window_idx = nullptr; if (ctx->has_llava_projector) { // concat class_embeddings and patch_embeddings @@ -657,6 +667,28 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im const 
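A quick reference for how the new `clip.vision.fullatt_block_indexes` key is meant to be read: when the list is absent or empty, window attention is disabled and every block runs full attention; otherwise only the listed blocks do. A minimal Python sketch of that rule (helper name and index values are mine, purely illustrative):

def uses_full_attention(layer_idx: int, fullatt_block_indexes: list[int]) -> bool:
    # Mirrors the use_window_attn / full_attn_layers check added in clip.cpp.
    use_window_attn = len(fullatt_block_indexes) > 0
    return (layer_idx in fullatt_block_indexes) if use_window_attn else True

# Illustrative indexes (not necessarily the shipped config):
assert uses_full_attention(7, [7, 15, 23, 31])
assert not uses_full_attention(8, [7, 15, 23, 31])
assert uses_full_attention(8, [])   # no window attention configured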
auto & vision_feature_layer = hparams.vision_feature_layer; // loop over layers + + if (use_window_attn) { + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); + + // mask for window attention + window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions); + ggml_set_name(window_mask, "window_mask"); + ggml_set_input(window_mask); + + // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size); + + positions = ggml_reshape_2d(ctx0, positions, 16, num_position_ids / 4 / 4); + positions = ggml_get_rows(ctx0, positions, window_idx); + positions = ggml_reshape_1d(ctx0, positions, num_position_ids); + } + for (int il = 0; il < ctx->max_feature_layer; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states @@ -669,9 +701,12 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im //const size_t nb_q_w = model.layers[il].q_w->nb[0]; // layernorm1 - { + if (ctx->use_rms_norm) { + cur = ggml_rms_norm(ctx0, cur, eps); + cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w); + } + else { cur = ggml_norm(ctx0, cur, eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b); } @@ -711,7 +746,14 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); + const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end(); + const bool full_attn = use_window_attn ? 
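In NumPy terms, the reshape / get_rows / reshape sequence above amounts to gathering whole 2x2-merged groups of patch embeddings by row. A small sketch with toy sizes and a placeholder permutation (the real index tensor is filled in later, at encode time):

import numpy as np

hidden_size, n_tokens = 8, 16            # toy sizes; n_tokens = patches_w * patches_h
emb = np.arange(n_tokens * hidden_size, dtype=np.float32).reshape(n_tokens, hidden_size)
window_order = np.array([3, 2, 1, 0])    # placeholder permutation of the 2x2-merged groups

grouped   = emb.reshape(n_tokens // 4, 4 * hidden_size)   # one row per merged group of 4 patches
reordered = grouped[window_order].reshape(n_tokens, hidden_size)
print(reordered.shape)                   # (16, 8): same data, groups reordered window-first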
inlist : true; + if (full_attn) { + KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); + } else { + KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f, 0.0f); + } + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); @@ -728,25 +770,50 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = cur; // embeddings = residual, cur = hidden_states // layernorm2 - { + if (ctx->use_rms_norm) { + cur = ggml_rms_norm(ctx0, cur, eps); + cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w); + } else { cur = ggml_norm(ctx0, cur, eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b); } - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + // mlp + if (ctx->use_glu_mlp) { + // ffn_up + auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b); + + auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur); + cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b); + if (ctx->use_gelu) { + cur_gate = ggml_gelu_inplace(ctx0, cur_gate); + } else if (ctx->use_silu) { + cur_gate = ggml_silu_inplace(ctx0, cur_gate); + } else { + cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate); + } + cur = ggml_mul(ctx0, cur_gate, cur_up); - if (ctx->use_gelu) { - cur = ggml_gelu_inplace(ctx0, cur); - } else if (ctx->use_silu) { - cur = ggml_silu_inplace(ctx0, cur); - } else { - cur = ggml_gelu_quick_inplace(ctx0, cur); + // ffn_down + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); } + else { + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + + if (ctx->use_gelu) { + cur = ggml_gelu_inplace(ctx0, cur); + } else if (ctx->use_silu) { + cur = ggml_silu_inplace(ctx0, cur); + } else { + cur = ggml_gelu_quick_inplace(ctx0, cur); + } - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); + cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); + } // residual 2 cur = ggml_add(ctx0, embeddings, cur); @@ -756,10 +823,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // post-layernorm if (model.post_ln_w) { - embeddings = ggml_norm(ctx0, embeddings, eps); - ggml_set_name(embeddings, "post_ln"); + if (ctx->use_rms_norm) { + embeddings = ggml_rms_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); + embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w); + } else { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); + } } // final layer is a vision feature layer @@ -1073,6 +1147,18 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); } + if (use_window_attn) { + struct ggml_tensor * inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); + + // embeddings shape: [hidden_size, 
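For reference, the gated MLP wired up above (ffn_gate and ffn_up multiplied elementwise, then ffn_down) as a minimal NumPy sketch; SiLU stands in for whichever activation the use_gelu/use_silu flags select, and all names and shapes here are toy placeholders, not the model's:

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def glu_mlp(x, w_gate, b_gate, w_up, b_up, w_down, b_down):
    gate = silu(x @ w_gate.T + b_gate)        # activation only on the gate branch
    up   = x @ w_up.T + b_up
    return (gate * up) @ w_down.T + b_down    # down-projection back to hidden_size

rng = np.random.default_rng(0)
hidden, ffn = 8, 16
x = rng.standard_normal((3, hidden)).astype(np.float32)
w_gate, w_up = (rng.standard_normal((ffn, hidden)).astype(np.float32) for _ in range(2))
w_down = rng.standard_normal((hidden, ffn)).astype(np.float32)
zeros = lambda n: np.zeros(n, dtype=np.float32)
print(glu_mlp(x, w_gate, zeros(ffn), w_up, zeros(ffn), w_down, zeros(hidden)).shape)  # (3, 8)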
patches_w * patches_h, batch_size] + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4); + embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size); + } + // build the graph ggml_build_forward_expand(gf, embeddings); @@ -1175,6 +1261,8 @@ struct clip_model_loader { get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false); get_bool(KEY_USE_SILU, ctx_clip.use_silu, false); + get_bool(KEY_USE_GLU_MLP, ctx_clip.use_glu_mlp, false); + get_bool(KEY_USE_RMS_NORM, ctx_clip.use_rms_norm, false); auto & hparams = ctx_clip.vision_model.hparams; get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size); @@ -1187,6 +1275,7 @@ struct clip_model_loader { get_u32(KEY_PATCH_SIZE, hparams.patch_size); get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); + get_arr_int(KEY_FULLATTN_BLK_IDX, hparams.full_attn_layers, false); { std::string mm_patch_merge_type; @@ -1302,14 +1391,16 @@ struct clip_model_loader { layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false); layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight")); layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight")); + layer.ff_g_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), ctx_clip.use_glu_mlp); layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false); layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false); layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false); layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false); - layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false); - layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false); + layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), !ctx_clip.use_rms_norm); + layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), !ctx_clip.use_rms_norm); layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false); layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false); + layer.ff_g_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), ctx_clip.use_glu_mlp); } switch (ctx_clip.proj_type) { diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py index c87606b4fdf4f..8f7a94e5c3797 100644 --- a/examples/llava/qwen2_vl_surgery.py +++ b/examples/llava/qwen2_vl_surgery.py @@ -5,10 +5,12 @@ import numpy as np from gguf import * from transformers import ( + AutoProcessor, Qwen2VLForConditionalGeneration, + Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor, - AutoProcessor, - Qwen2VLConfig + Qwen2VLConfig, + Qwen2_5_VLConfig, ) @@ -18,62 +20,80 @@ def k(raw_key: str, arch: str) -> str: return raw_key.format(arch=arch) +class VL2: + + @staticmethod + def to_gguf_name(name: str) -> str: + og = name + name = name.replace("text_model", "t").replace("vision_model", "v") + name = name.replace("blocks", "blk").replace("embeddings.", "") + name = name.replace("attn.", "attn_") + name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.") + # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln") + name = name.replace("norm1", "ln1").replace("norm2", "ln2") + name = name.replace("merger.mlp", 
'mm') + print(f"[to_gguf_name] {og} --> {name}") + return name + + @classmethod + def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]: + vision_model = qwen2vl.visual + tensor_map = {} + for name, ten in vision_model.state_dict().items(): + ten = ten.numpy() + if 'qkv' in name: + if ten.ndim == 2: # weight + c3, _ = ten.shape + else: # bias + c3 = ten.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = ten[:c] + wk = ten[c: c * 2] + wv = ten[c * 2:] + tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq + tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk + tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv + elif 'merger' in name: + if name.endswith("ln_q.weight"): + tensor_map['v.post_ln.weight'] = ten + elif name.endswith("ln_q.bias"): + tensor_map['v.post_ln.bias'] = ten + else: + # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias" + tensor_map[cls.to_gguf_name(name)] = ten + elif 'patch_embed.proj.weight' in name: + # NOTE: split Conv3D into Conv2Ds + c1, c2, kt, kh, kw = ten.shape + assert kt == 2, "Current implmentation only support temporal_patch_size of 2" + tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...] + tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...] + else: + tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten -def to_gguf_name(name: str) -> str: - og = name - name = name.replace("text_model", "t").replace("vision_model", "v") - name = name.replace("blocks", "blk").replace("embeddings.", "") - name = name.replace("attn.", "attn_") - name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.") - # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln") - name = name.replace("norm1", "ln1").replace("norm2", "ln2") - name = name.replace("merger.mlp", 'mm') - print(f"[to_gguf_name] {og} --> {name}") - return name - - -def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]: - vision_model = qwen2vl.visual - tensor_map = {} - for name, ten in vision_model.state_dict().items(): - ten = ten.numpy() - if 'qkv' in name: - if ten.ndim == 2: # weight - c3, _ = ten.shape - else: # bias - c3 = ten.shape[0] - assert c3 % 3 == 0 - c = c3 // 3 - wq = ten[:c] - wk = ten[c: c * 2] - wv = ten[c * 2:] - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv - elif 'merger' in name: - if name.endswith("ln_q.weight"): - tensor_map['v.post_ln.weight'] = ten - elif name.endswith("ln_q.bias"): - tensor_map['v.post_ln.bias'] = ten + for new_name, ten in tensor_map.items(): + if ten.ndim <= 1 or new_name.endswith("_norm.weight"): + tensor_map[new_name] = ten.astype(np.float32) else: - # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias" - tensor_map[to_gguf_name(name)] = ten - elif 'patch_embed.proj.weight' in name: - # NOTE: split Conv3D into Conv2Ds - c1, c2, kt, kh, kw = ten.shape - assert kt == 2, "Current implmentation only support temporal_patch_size of 2" - tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...] - tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...] 
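The two weight-surgery steps in find_vision_tensors above, sketched with NumPy on toy shapes (the real tensors come from the checkpoint, not the zeros/aranges used here):

import numpy as np

# 1) fused qkv projection: [3*c, c] weight (or [3*c] bias) -> equal q/k/v slices
c = 4
qkv_w = np.arange(3 * c * c, dtype=np.float32).reshape(3 * c, c)
assert qkv_w.shape[0] % 3 == 0
q_w, k_w, v_w = qkv_w[:c], qkv_w[c:2 * c], qkv_w[2 * c:]

# 2) Conv3D patch embedding [out_c, in_c, kt, kh, kw] with kt == 2
#    -> two Conv2D kernels, one per temporal slice
out_c, in_c, kt, kh, kw = 8, 3, 2, 14, 14
conv3d_w = np.zeros((out_c, in_c, kt, kh, kw), dtype=np.float32)
patch_embd_0 = conv3d_w[:, :, 0, ...]   # stored as "v.patch_embd.weight"
patch_embd_1 = conv3d_w[:, :, 1, ...]   # stored as "v.patch_embd.weight.1"
print(q_w.shape, patch_embd_0.shape)    # (4, 4) (8, 3, 14, 14)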
- else: - tensor_map[to_gguf_name(f"vision_model.{name}")] = ten - - for new_name, ten in tensor_map.items(): - if ten.ndim <= 1 or new_name.endswith("_norm.weight"): - tensor_map[new_name] = ten.astype(np.float32) - else: - tensor_map[new_name] = ten.astype(dtype) - tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder - return tensor_map + tensor_map[new_name] = ten.astype(dtype) + tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder + return tensor_map + + +class VL25(VL2): + + @staticmethod + def to_gguf_name(name: str) -> str: + og = name + name = name.replace("text_model", "t").replace("vision_model", "v") + name = name.replace("blocks", "blk").replace("embeddings.", "") + name = name.replace("attn.", "attn_") + name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up") + name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.") + name = name.replace("norm1", "ln1").replace("norm2", "ln2") + name = name.replace("merger.mlp", 'mm') + print(f"[vl25][to_gguf_name] {og} --> {name}") + return name def main(args): @@ -92,11 +112,18 @@ def main(args): model_path = "" model_name = args.model_name print("model_name: ", model_name) - qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( - model_name, torch_dtype=dtype, device_map="cpu" - ) - cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType] - vcfg = cfg.vision_config + if args.model_type == "qwen2vl": + qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( + model_name, torch_dtype=dtype, device_map="cpu" + ) + cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType] + vcfg = cfg.vision_config + else: + qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_name, torch_dtype=dtype, device_map="cpu" + ) + cfg: Qwen2_5_VLConfig = qwen2vl.config # type: ignore[reportAssignmentType] + vcfg = cfg.vision_config if os.path.isdir(model_name): local_model = True @@ -125,14 +152,26 @@ def main(args): else: raise ValueError() - tensor_map = find_vision_tensors(qwen2vl, np_dtype) + if args.model_type == "qwen2.5vl": + fout.add_bool("clip.use_glu_mlp", True) # gate linear unit MLP layer in vision model + fout.add_bool("clip.use_rms_norm", True) + fout.add_array("clip.vision.fullatt_block_indexes", vcfg.fullatt_block_indexes) + fout.add_uint32("clip.vision.window_size", vcfg.window_size) + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size) + fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size) + else: + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim) + fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size) + + if args.model_type == "qwen2.5vl": + tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype) + else: + tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype) for name, data in tensor_map.items(): fout.add_tensor(name, data) fout.add_uint32("clip.vision.patch_size", vcfg.patch_size) fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2) - fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim) - fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth) @@ -160,6 +199,7 @@ def main(args): if __name__ == "__main__": parser 
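A few representative inputs and the names the VL25 mapper above should produce, traced by hand from its replace chain (the attn_qkv result is split into q/k/v by the caller); assumes the VL25 class from this file is importable:

for name in [
    "vision_model.blocks.0.attn.qkv.weight",       # -> v.blk.0.attn_qkv.weight
    "vision_model.blocks.0.attn.proj.weight",      # -> v.blk.0.attn_out.weight
    "vision_model.blocks.0.mlp.gate_proj.weight",  # -> v.blk.0.ffn_gate.weight
    "vision_model.blocks.0.mlp.up_proj.weight",    # -> v.blk.0.ffn_up.weight
    "vision_model.blocks.0.mlp.down_proj.weight",  # -> v.blk.0.ffn_down.weight
    "vision_model.blocks.0.norm1.weight",          # -> v.blk.0.ln1.weight
]:
    print(VL25.to_gguf_name(name))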
= argparse.ArgumentParser() parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct") + parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl") parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32") args = parser.parse_args() main(args) From 3f2ee541df36a78dca6ec22e70a1f4b27b6d1fb0 Mon Sep 17 00:00:00 2001 From: HimariO Date: Mon, 3 Feb 2025 01:58:53 +0800 Subject: [PATCH 02/12] handle window attention inputs --- examples/llava/clip.cpp | 61 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 986db312f2e8f..0b3b8e70fc2ec 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -166,6 +166,7 @@ struct clip_hparams { std::vector image_grid_pinpoints; int32_t image_crop_resolution; std::unordered_set vision_feature_layer; + int32_t attn_window_size; std::vector full_attn_layers; }; @@ -1274,6 +1275,7 @@ struct clip_model_loader { get_u32(KEY_IMAGE_SIZE, hparams.image_size); get_u32(KEY_PATCH_SIZE, hparams.patch_size); get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); + get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, false); get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); get_arr_int(KEY_FULLATTN_BLK_IDX, hparams.full_attn_layers, false); @@ -2590,6 +2592,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); free(data); } + if (ctx->has_minicpmv_projector) { { // inspired from siglip: @@ -2708,6 +2711,64 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } + if (hparams.attn_window_size > 0 && ctx->has_qwen2vl_merger) { // TODO: add use_window_attn? 
+ struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx"); + struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); + struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask"); + + const int merge_ratio = 2; + const int pw = image_size_width / patch_size / merge_ratio; + const int ph = image_size_height / patch_size / merge_ratio; + const int grid_window = hparams.attn_window_size / hparams.patch_size / merge_ratio; + const int ipw = image_size_width / patch_size; + const int iph = image_size_height / patch_size; + /* + pw * ph = number of tokens output by ViT after apply patch merger + ipw * ipw = number of vision token been processed inside ViT + */ + + std::vector idx(ph * pw); + std::vector inv_idx(ph * pw); + int dst = 0; + // [num_vision_tokens, num_vision_tokens] attention mask tensor + std::vector mask(pow(ipw * iph, 2), std::numeric_limits::lowest()); + int mask_row = 0; + + for (int y = 0; y < ph; y+=grid_window) + { + for (int x = 0; x < pw; x+=grid_window) + { + const int win_h = std::min(grid_window, ph - y); + const int win_w = std::min(grid_window, pw - x); + const int dst_0 = dst; + // group all tokens belong to the same window togather (to a continue range) + for (int dy = 0; dy < win_h; dy++) { + for (int dx = 0; dx < win_w; dx++) { + const int src = (y + dy) * pw + (x + dx); + assert(src < (int)idx.size()); + assert(dst < (int)inv_idx.size()); + idx[src] = dst; + inv_idx[dst] = src; + dst++; + } + } + + for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { + int row_offset = mask_row * (ipw * iph); + std::fill( + mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), + mask.begin() + row_offset + (dst * merge_ratio * merge_ratio), + 0.0); + mask_row++; + } + } + } + + ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); + ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); + ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); + } + ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads); auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf); From 9d06730a3526b65ea40f9a75d1aa55dd5583bb5d Mon Sep 17 00:00:00 2001 From: HimariO Date: Tue, 4 Feb 2025 22:24:27 +0800 Subject: [PATCH 03/12] add support for `Qwen2_5_VLForConditionalGeneration` --- convert_hf_to_gguf.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2bf97475f78dd..be97c8da10691 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2375,6 +2375,11 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: yield name, data +@Model.register("Qwen2_5_VLForConditionalGeneration") +class Qwen25VLModel(Qwen2VLModel): + model_arch = gguf.MODEL_ARCH.QWEN2VL + + @Model.register("WavTokenizerDec") class WavTokenizerDecModel(Model): model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC From e9043cf3440f67471e336bb0da804cda81e554f7 Mon Sep 17 00:00:00 2001 From: HimariO Date: Sat, 15 Mar 2025 23:03:50 +0800 Subject: [PATCH 04/12] add debug utils --- examples/llava/qwen2vl-cli.cpp | 427 ++++++++++++++++++++++++++++++++- 1 file changed, 416 insertions(+), 11 deletions(-) diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index eca7b7f10b9e3..a870bc4be3db5 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -23,6 +23,9 @@ #include #include #include +#include +#include +#include static bool 
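The index/mask construction above, ported to a self-contained NumPy sketch with toy sizes (in the patch, grid_window = attn_window_size / patch_size / merge_ratio, e.g. 112 / 14 / 2 = 4):

import numpy as np

def build_window_inputs(ipw, iph, grid_window, merge_ratio=2):
    # idx[src] = dst maps row-major merged-token order to window-grouped order,
    # inv_idx is its inverse, and mask is the [ipw*iph, ipw*iph] attention bias:
    # 0 inside a window, lowest float (effectively -inf) everywhere else.
    pw, ph = ipw // merge_ratio, iph // merge_ratio
    mr2 = merge_ratio * merge_ratio
    idx = np.zeros(ph * pw, dtype=np.int32)
    inv_idx = np.zeros(ph * pw, dtype=np.int32)
    mask = np.full((ipw * iph, ipw * iph), np.finfo(np.float32).min, dtype=np.float32)

    dst, mask_row = 0, 0
    for y in range(0, ph, grid_window):
        for x in range(0, pw, grid_window):
            win_h, win_w = min(grid_window, ph - y), min(grid_window, pw - x)
            dst_0 = dst
            for dy in range(win_h):
                for dx in range(win_w):
                    src = (y + dy) * pw + (x + dx)
                    idx[src], inv_idx[dst] = dst, src
                    dst += 1
            for _ in range(win_h * win_w * mr2):
                mask[mask_row, dst_0 * mr2: dst * mr2] = 0.0
                mask_row += 1
    return idx, inv_idx, mask

# toy image: 8x8 patches -> 4x4 merged tokens, windows of 2x2 merged tokens
idx, inv_idx, mask = build_window_inputs(ipw=8, iph=8, grid_window=2)
print(idx, (mask == 0).sum(axis=1)[:4])   # each of the first rows attends to 16 tokens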
qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, @@ -483,33 +486,431 @@ static void debug_test_mrope_2d() { ggml_backend_free(backend); } -static void debug_dump_img_embed(struct llava_context * ctx_llava) { - int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama)); - int ne = n_embd * 4; - float vals[56 * 56 * 3]; +static void debug_patch_layout() { + // 1. Initialize backend + ggml_backend_t backend = NULL; + std::string backend_name = ""; +#ifdef GGML_USE_CUDA + fprintf(stderr, "%s: using CUDA backend\n", __func__); + backend = ggml_backend_cuda_init(0); // init device 0 + backend_name = "cuda"; + if (!backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } +#endif + // if there aren't GPU Backends fallback to CPU backend + if (!backend) { + backend = ggml_backend_cpu_init(); + backend_name = "cpu"; + } + + // Calculate the size needed to allocate + size_t ctx_size = 0; + ctx_size += 2 * ggml_tensor_overhead(); // tensors + // no need to allocate anything else! + + // 2. Allocate `ggml_context` to store tensor data + struct ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors() + }; + struct ggml_context * ctx = ggml_init(params); + + const int patches_w = 14; + const int patches_h = 10; + const int c = 2; + const int batch_size = 1; + struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, patches_w, patches_h, c, batch_size); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + + + std::vector dummy_q; + dummy_q.resize(patches_w * patches_h * c * batch_size); + for (size_t i = 0; i < patches_h * patches_w * c; i++) + { + dummy_q[i] = i; + } + + // std::fill(dummy_q.begin(), dummy_q.end(), 0.1); + // memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw)); + + // 4. Allocate a `ggml_backend_buffer` to store all tensors + ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + + // 5. Copy tensor data from main memory (RAM) to backend buffer + ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw)); + + // 6. Create a `ggml_cgraph` for mul_mat operation + struct ggml_cgraph * gf = NULL; + struct ggml_context * ctx0 = NULL; + + // create a temporally context to build the graph + struct ggml_init_params params0 = { + /*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + ctx0 = ggml_init(params0); + gf = ggml_new_graph(ctx0); + /* + Compute graph + */ + struct ggml_tensor * inp = ggml_cont(ctx0, ggml_permute(ctx0, inp_raw, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] + + inp = ggml_reshape_4d( + ctx0, inp, + c * 2, patches_w / 2, patches_h, batch_size); + inp = ggml_reshape_4d( + ctx0, inp, + c * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); + inp = ggml_reshape_3d( + ctx0, inp, + c, patches_w * patches_h, batch_size); + + // Add "result" tensor and all of its dependencies to the cgraph + ggml_build_forward_expand(gf, inp); + + // 7. Create a `ggml_gallocr` for cgraph computation + ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_alloc_graph(allocr, gf); + + // 9. 
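debug_patch_layout above dumps the result of the permute/reshape chain used for the patch embeddings; the token ordering it is meant to produce (2x2 spatial blocks flattened into consecutive tokens, blocks visited row-major, as the positions loop elsewhere in the patch suggests) can be sketched independently in NumPy as a sanity reference. This is my reading of the layout, not code from the patch:

import numpy as np

def merge_patch_order(feat, merge=2):
    # feat: [h, w, c] grid of per-patch features -> [h*w, c] tokens, where each
    # merge x merge block contributes merge*merge consecutive tokens.
    h, w, c = feat.shape
    t = feat.reshape(h // merge, merge, w // merge, merge, c)
    t = t.transpose(0, 2, 1, 3, 4)          # [h/2, w/2, 2, 2, c]
    return t.reshape(h * w, c)

feat = np.arange(4 * 6, dtype=np.float32).reshape(4, 6, 1)   # 4x6 patch grid, c = 1
print(merge_patch_order(feat)[:4, 0])                        # [0. 1. 6. 7.] = top-left 2x2 block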
Run the computation + int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, n_threads); + } + ggml_backend_graph_compute(backend, gf); + + // 10. Retrieve results (output tensors) + // in this example, output tensor is always the last tensor in the graph + struct ggml_tensor * result = inp; + // struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1]; + float * result_data = (float *)malloc(ggml_nbytes(result)); + // because the tensor data is stored in device buffer, we need to copy it back to RAM + ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result)); + const std::string bin_file = "patch_layout_" + backend_name +".bin"; + std::ofstream outFile(bin_file, std::ios::binary); + + if (outFile.is_open()) { + outFile.write(reinterpret_cast(result_data), ggml_nbytes(result)); + outFile.close(); + std::cout << "Data successfully written to " + bin_file << std::endl; + } else { + std::cerr << "Error opening file!" << std::endl; + } + + free(result_data); + // 11. Free memory and exit + ggml_free(ctx0); + ggml_gallocr_free(allocr); + ggml_free(ctx); + ggml_backend_buffer_free(buffer); + ggml_backend_free(backend); +} + +static void debug_test_get_rows() { + // 1. Initialize backend + ggml_backend_t backend = NULL; + std::string backend_name = ""; +#ifdef GGML_USE_CUDA + fprintf(stderr, "%s: using CUDA backend\n", __func__); + backend = ggml_backend_cuda_init(0); // init device 0 + backend_name = "cuda"; + if (!backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } +#endif + // if there aren't GPU Backends fallback to CPU backend + if (!backend) { + backend = ggml_backend_cpu_init(); + backend_name = "cpu"; + } + + // Calculate the size needed to allocate + size_t ctx_size = 0; + ctx_size += 128 * ggml_tensor_overhead(); // tensors + // no need to allocate anything else! + + // 2. 
Allocate `ggml_context` to store tensor data + struct ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors() + }; + struct ggml_context * ctx = ggml_init(params); + + const int tokens = 30; + struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 3, tokens * 2); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + + struct ggml_tensor * pos = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 4, tokens); + // struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, tokens * 4); + ggml_set_name(pos, "pos"); + ggml_set_input(pos); + + struct ggml_tensor * ind = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, tokens); + ggml_set_name(ind, "ind"); + ggml_set_input(ind); + + struct ggml_tensor * ind_2d = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 1, tokens); + ggml_set_name(ind_2d, "ind_2d"); + ggml_set_input(ind_2d); + + std::vector dummy_q; + dummy_q.resize(128 * 3 * inp_raw->ne[2]); + for (int i = 0; i < inp_raw->ne[2]; i ++) { + for (int j = 0; j < 3; j ++) { + int offset = i * 128 * 3 + j * 128; + std::fill(dummy_q.begin() + offset, dummy_q.begin() + offset + 128, 0.1 * i); + } + } + // std::fill(dummy_q.begin(), dummy_q.end(), 0.1); + // memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw)); + + std::vector pos_id; + pos_id.resize(tokens * 4); + for (int i = 0; i < tokens; i ++) { + pos_id[i] = i; + pos_id[i + tokens * 1] = i + 10; + pos_id[i + tokens * 2] = i + 20; + pos_id[i + tokens * 3] = i + 30; + } + + std::vector remap_ind; + remap_ind.resize(tokens * 4); + for (int i = 0; i < tokens; i ++) { + remap_ind[i] = tokens - i - 1; + } + + // 4. Allocate a `ggml_backend_buffer` to store all tensors + ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + + // 5. Copy tensor data from main memory (RAM) to backend buffer + ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw)); + ggml_backend_tensor_set(pos, pos_id.data(), 0, ggml_nbytes(pos)); + ggml_backend_tensor_set(ind, remap_ind.data(), 0, ggml_nbytes(ind)); + ggml_backend_tensor_set(ind_2d, remap_ind.data(), 0, ggml_nbytes(ind_2d)); + + // 6. Create a `ggml_cgraph` for mul_mat operation + struct ggml_cgraph * gf = NULL; + struct ggml_context * ctx_cgraph = NULL; + + // create a temporally context to build the graph + struct ggml_init_params params0 = { + /*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + ctx_cgraph = ggml_init(params0); + gf = ggml_new_graph(ctx_cgraph); + + // ne = [128, 1, 30, 1] + auto x = ggml_reshape_2d(ctx_cgraph, inp_raw, 128 * 3 * 2, tokens); + struct ggml_tensor * result0 = ggml_get_rows( + ctx_cgraph, x, ind); + result0 = ggml_reshape_3d(ctx_cgraph, result0, 128, 3, tokens * 2); + + struct ggml_tensor * result1 = ggml_get_rows( + ctx_cgraph, pos, ind); + + // Add "result" tensor and all of its dependencies to the cgraph + ggml_build_forward_expand(gf, result0); + ggml_build_forward_expand(gf, result1); + + // 7. Create a `ggml_gallocr` for cgraph computation + ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_alloc_graph(allocr, gf); + + // 9. 
Run the computation + int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, n_threads); + } + ggml_backend_graph_compute(backend, gf); + + // 10. Retrieve results (output tensors) + // in this example, output tensor is always the last tensor in the graph + struct ggml_tensor * result = result0; + // struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1]; + float * result_data = (float *)malloc(ggml_nbytes(result)); + // because the tensor data is stored in device buffer, we need to copy it back to RAM + ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result)); + const std::string bin_file = "getrows_" + backend_name +"_0.bin"; + std::ofstream outFile(bin_file, std::ios::binary); + + if (outFile.is_open()) { + outFile.write(reinterpret_cast(result_data), ggml_nbytes(result)); + outFile.close(); + std::cout << "Data successfully written to " + bin_file << std::endl; + } else { + std::cerr << "Error opening file!" << std::endl; + } + + free(result_data); + // 11. Free memory and exit + ggml_free(ctx_cgraph); + ggml_gallocr_free(allocr); + ggml_free(ctx); + ggml_backend_buffer_free(buffer); + ggml_backend_free(backend); +} + + +enum model_output_type { + conv3d, + patch_embed, + patch_win_attn_scatter, + first_attn_layer, + last_attn_layer, + attn_softmax, + final_layer, +}; + +static void debug_dump_img_embed(struct llava_context * ctx_llava, model_output_type output_type) { + int ih = 140; + int iw = 196; + // int ih = 56; + // int iw = 56; + // int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama)); + int n_embd = 1280; + int merge = 1; + if (output_type == model_output_type::final_layer) { + n_embd = 2048; + merge = 2; + } + else if (output_type == model_output_type::attn_softmax) { + merge = 1; + n_embd = (ih/14/merge) * (iw/14/merge) * 16; + } + + int ne = (ih/14/merge) * (iw/14/merge) * n_embd; + float vals[iw * ih * 3]; // float embd[ne]; std::vector embd; embd.resize(ne); - for (int i = 0; i < 56*56; i++) + for (int i = 0; i < iw*ih; i++) { for (int c = 0; c < 3; c++) - vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56); + vals[i * 3 + c] = (float)i / (iw*ih); } - clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd.data()); + clip_encode_float_image(ctx_llava->ctx_clip, 8, vals, ih, iw, embd.data()); + + std::string file_postfix = ""; + switch (output_type) + { + case model_output_type::conv3d: + file_postfix = "conv3d"; + break; + case model_output_type::patch_embed: + file_postfix = "patch_embed"; + break; + case model_output_type::patch_win_attn_scatter: + file_postfix = "scatter"; + break; + case model_output_type::first_attn_layer: + file_postfix = "first_attn"; + break; + case model_output_type::last_attn_layer: + file_postfix = "last_attn"; + break; + case model_output_type::attn_softmax: + file_postfix = "attn_softmax"; + break; + case model_output_type::final_layer: + file_postfix = "final"; + break; + default: + break; + } + auto output_path = "img_embed_" + file_postfix + ".bin"; - std::ofstream outFile("img_embed.bin", std::ios::binary); + std::ofstream outFile(output_path, std::ios::binary); if (outFile.is_open()) { outFile.write(reinterpret_cast(embd.data()), ne * sizeof(float)); outFile.close(); - std::cout << "Data successfully written to mrope.bin" << std::endl; + std::cout << "Data successfully written to ::[ " << output_path << std::endl; + } else { + std::cerr << "Error opening file!" 
<< std::endl; + } +} + + +static void dump_win_attn_mask() { + const int image_size_width = 196; + const int image_size_height = 140; + const int patch_size = 14; + const int attn_window_size = 112; + + const int merge_ratio = 2; + const int ipw = image_size_width / patch_size; + const int iph = image_size_height / patch_size; + const int pw = image_size_width / patch_size / merge_ratio; + const int ph = image_size_height / patch_size / merge_ratio; + const int grid_window = attn_window_size / patch_size / merge_ratio; + /* + pw * ph = number of tokens output by ViT after apply patch merger + ipw * ipw = number of vision token been processed inside ViT + */ + + std::vector idx(ph * pw); + std::vector inv_idx(ph * pw); + int dst = 0; + // [num_vision_tokens, num_vision_tokens] attention mask tensor + int ne = pow(ipw * iph, 2); + std::vector mask(ne, std::numeric_limits::lowest()); + int mask_row = 0; + + for (int y = 0; y < ph; y+=grid_window) + { + for (int x = 0; x < pw; x+=grid_window) + { + const int win_h = std::min(grid_window, ph - y); + const int win_w = std::min(grid_window, pw - x); + const int dst_0 = dst; + // group all tokens belong to the same window togather (to a continue range) + for (int dy = 0; dy < win_h; dy++) { + for (int dx = 0; dx < win_w; dx++) { + const int src = (y + dy) * pw + (x + dx); + assert(src < (int)idx.size()); + assert(dst < (int)inv_idx.size()); + idx[src] = dst; + inv_idx[dst] = src; + dst++; + } + } + + for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { + int row_offset = mask_row * (ipw * iph); + std::fill( + mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), + mask.begin() + row_offset + (dst * merge_ratio * merge_ratio), + 0.0); + mask_row++; + } + } + } + + auto output_path = "win_attn_mask_fp32.bin"; + + std::ofstream outFile(output_path, std::ios::binary); + if (outFile.is_open()) { + outFile.write(reinterpret_cast(mask.data()), ne * sizeof(float)); + + outFile.close(); + std::cout << "Data successfully written to " << output_path << std::endl; } else { std::cerr << "Error opening file!" 
<< std::endl; } } + #endif @@ -551,8 +952,12 @@ int main(int argc, char ** argv) { } else if (params.image[0].empty()) { auto ctx_llava = llava_init_context(¶ms, model); - debug_test_mrope_2d(); - debug_dump_img_embed(ctx_llava); + // debug_test_mrope_2d(); + debug_dump_img_embed(ctx_llava, model_output_type::final_layer); + // debug_dump_img_embed(ctx_llava, model_output_type::conv3d); + // debug_test_get_rows(); + // dump_win_attn_mask(); + // debug_patch_layout(); llama_perf_context_print(ctx_llava->ctx_llama); ctx_llava->model = NULL; From 7bd5eb751cbe8a4e188f5432dcf187c7d5dea209 Mon Sep 17 00:00:00 2001 From: HimariO Date: Sat, 15 Mar 2025 23:04:24 +0800 Subject: [PATCH 05/12] fix few incorrect tensor memory layout --- examples/llava/clip.cpp | 104 +++++++++++++++++++++++++++++++--------- 1 file changed, 82 insertions(+), 22 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 0b3b8e70fc2ec..de71abcef9560 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -592,6 +592,11 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); inp = ggml_add(ctx0, inp, inp_1); + + // ggml_build_forward_expand(gf, inp); + // ggml_free(ctx0); + // return gf; + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] inp = ggml_reshape_4d( ctx0, inp, @@ -603,6 +608,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im inp = ggml_reshape_3d( ctx0, inp, hidden_size, patches_w * patches_h, batch_size); + + // ggml_build_forward_expand(gf, inp); + // ggml_free(ctx0); + // return gf; } else { inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); @@ -613,10 +622,11 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); inp = ggml_add(ctx0, inp, model.patch_bias); } - struct ggml_tensor * embeddings = inp; - struct ggml_tensor * pos_embed = nullptr; - struct ggml_tensor * window_mask = nullptr; - struct ggml_tensor * window_idx = nullptr; + struct ggml_tensor * embeddings = inp; + struct ggml_tensor * pos_embed = nullptr; + struct ggml_tensor * window_mask = nullptr; + struct ggml_tensor * window_idx = nullptr; + struct ggml_tensor * inv_window_idx = nullptr; if (ctx->has_llava_projector) { // concat class_embeddings and patch_embeddings @@ -658,10 +668,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // pre-layernorm if (model.pre_ln_w) { - embeddings = ggml_norm(ctx0, embeddings, eps); - ggml_set_name(embeddings, "pre_ln"); + if (ctx->use_rms_norm) { + embeddings = ggml_rms_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "pre_ln"); + + embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w); + } else { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "pre_ln"); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); + } } std::vector embedding_stack; @@ -670,10 +687,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // loop over layers if (use_window_attn) { - window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); - ggml_set_name(window_idx, "window_idx"); - ggml_set_input(window_idx); - + 
inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); // mask for window attention window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions); ggml_set_name(window_mask, "window_mask"); @@ -682,12 +698,20 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] GGML_ASSERT(batch_size == 1); embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4); - embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx); embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size); - positions = ggml_reshape_2d(ctx0, positions, 16, num_position_ids / 4 / 4); - positions = ggml_get_rows(ctx0, positions, window_idx); + positions = ggml_reshape_2d(ctx0, positions, num_position_ids / 4, 4); + positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3)); + positions = ggml_reshape_2d(ctx0, positions, 16, num_position_ids / 16); + positions = ggml_get_rows(ctx0, positions, inv_window_idx); + positions = ggml_reshape_2d(ctx0, positions, 4, num_position_ids / 4); + positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3)); positions = ggml_reshape_1d(ctx0, positions, num_position_ids); + + // ggml_build_forward_expand(gf, embeddings); + // ggml_free(ctx0); + // return gf; } for (int il = 0; il < ctx->max_feature_layer; il++) { @@ -711,6 +735,12 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b); } + // if ( il == 0) { + // // build the graph + // ggml_build_forward_expand(gf, cur); + // ggml_free(ctx0); + // return gf; + // } // self-attention { @@ -753,7 +783,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); } else { KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f, 0.0f); + + // KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrt((float)d_head)); + // KQ = ggml_add(ctx0, KQ, window_mask); + // KQ = ggml_soft_max_inplace(ctx0, KQ); } + // if ( il == 0) { + // // build the graph + // ggml_build_forward_expand(gf, KQ); + // ggml_free(ctx0); + // return gf; + // } struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); @@ -769,6 +809,12 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im cur = ggml_add(ctx0, cur, embeddings); embeddings = cur; // embeddings = residual, cur = hidden_states + // if ( il == 0) { + // // build the graph + // ggml_build_forward_expand(gf, cur); + // ggml_free(ctx0); + // return gf; + // } // layernorm2 if (ctx->use_rms_norm) { @@ -820,8 +866,19 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im cur = ggml_add(ctx0, embeddings, cur); embeddings = cur; + + // if ( il == 0) { + // // build the graph + // ggml_build_forward_expand(gf, embeddings); + // ggml_free(ctx0); + // return gf; + // } } + // ggml_build_forward_expand(gf, embeddings); + // ggml_free(ctx0); + // return gf; + // post-layernorm if (model.post_ln_w) { if (ctx->use_rms_norm) { @@ -1149,14 +1206,14 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * 
ctx, const clip_im } if (use_window_attn) { - struct ggml_tensor * inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); - ggml_set_name(inv_window_idx, "inv_window_idx"); - ggml_set_input(inv_window_idx); + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] GGML_ASSERT(batch_size == 1); embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4); - embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size); } @@ -2657,6 +2714,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima if (ctx->has_qwen2vl_merger) { struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + if (positions) { const int pw = image_size_width / patch_size; const int ph = image_size_height / patch_size; @@ -2681,6 +2739,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); + } } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { // do nothing @@ -2719,7 +2778,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int merge_ratio = 2; const int pw = image_size_width / patch_size / merge_ratio; const int ph = image_size_height / patch_size / merge_ratio; - const int grid_window = hparams.attn_window_size / hparams.patch_size / merge_ratio; + const int grid_window = hparams.attn_window_size / patch_size / merge_ratio; const int ipw = image_size_width / patch_size; const int iph = image_size_height / patch_size; /* @@ -2764,9 +2823,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } - ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); - ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); - ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); + + if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); + if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); + if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); } ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads); From 9167bd251bb5edd0f68227b7f6e65a6c1eab0e2a Mon Sep 17 00:00:00 2001 From: HimariO Date: Sun, 16 Mar 2025 00:35:19 +0800 Subject: [PATCH 06/12] move position id remap out of ggml to avoid int32 cuda operations --- examples/llava/clip.cpp | 101 +++++++++++++++++++++++++++------ examples/llava/qwen2vl-cli.cpp | 48 ++++++++-------- 2 files changed, 107 insertions(+), 42 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index de71abcef9560..cdded38d37b77 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -27,6 +27,7 @@ #include #include #include +#include struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; @@ -701,13 +702,13 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx); embeddings = 
ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size); - positions = ggml_reshape_2d(ctx0, positions, num_position_ids / 4, 4); - positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3)); - positions = ggml_reshape_2d(ctx0, positions, 16, num_position_ids / 16); - positions = ggml_get_rows(ctx0, positions, inv_window_idx); - positions = ggml_reshape_2d(ctx0, positions, 4, num_position_ids / 4); - positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3)); - positions = ggml_reshape_1d(ctx0, positions, num_position_ids); + // positions = ggml_reshape_2d(ctx0, positions, num_position_ids / 4, 4); + // positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3)); + // positions = ggml_reshape_2d(ctx0, positions, 16, num_position_ids / 16); + // positions = ggml_get_rows(ctx0, positions, inv_window_idx); + // positions = ggml_reshape_2d(ctx0, positions, 4, num_position_ids / 4); + // positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3)); + // positions = ggml_reshape_1d(ctx0, positions, num_position_ids); // ggml_build_forward_expand(gf, embeddings); // ggml_free(ctx0); @@ -2713,33 +2714,97 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } if (ctx->has_qwen2vl_merger) { + /* + pw * ph = number of tokens output by ViT after apply patch merger + ipw * ipw = number of vision token been processed inside ViT + */ + const int merge_ratio = 2; + const int pw = image_size_width / patch_size / merge_ratio; + const int ph = image_size_height / patch_size / merge_ratio; + const int ipw = image_size_width / patch_size; + const int iph = image_size_height / patch_size; + + std::vector idx(ph * pw); + std::vector inv_idx(ph * pw); + + if (hparams.attn_window_size > 0) { + struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx"); + struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); + struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask"); + + const int grid_window = hparams.attn_window_size / patch_size / merge_ratio; + int dst = 0; + // [num_vision_tokens, num_vision_tokens] attention mask tensor + std::vector mask(pow(ipw * iph, 2), std::numeric_limits::lowest()); + int mask_row = 0; + + for (int y = 0; y < ph; y+=grid_window) + { + for (int x = 0; x < pw; x+=grid_window) + { + const int win_h = std::min(grid_window, ph - y); + const int win_w = std::min(grid_window, pw - x); + const int dst_0 = dst; + // group all tokens belong to the same window togather (to a continue range) + for (int dy = 0; dy < win_h; dy++) { + for (int dx = 0; dx < win_w; dx++) { + const int src = (y + dy) * pw + (x + dx); + assert(src < (int)idx.size()); + assert(dst < (int)inv_idx.size()); + idx[src] = dst; + inv_idx[dst] = src; + dst++; + } + } + + for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { + int row_offset = mask_row * (ipw * iph); + std::fill( + mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), + mask.begin() + row_offset + (dst * merge_ratio * merge_ratio), + 0.0); + mask_row++; + } + } + } + + if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); + if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); + if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); + } else { + std::iota(idx.begin(), idx.end(), 0); + std::iota(inv_idx.begin(), inv_idx.end(), 0); + } + struct 
ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); - if (positions) { - const int pw = image_size_width / patch_size; - const int ph = image_size_height / patch_size; + // const int pw = image_size_width / patch_size; + // const int ph = image_size_height / patch_size; + const int mpow = (merge_ratio * merge_ratio); int* positions_data = (int*)malloc(ggml_nbytes(positions)); int ptr = 0; - for (int y = 0; y < ph; y+=2) + for (int y = 0; y < iph; y+=merge_ratio) { - for (int x = 0; x < pw; x+=2) + for (int x = 0; x < ipw; x+=merge_ratio) { for (int dy = 0; dy < 2; dy++) { for (int dx = 0; dx < 2; dx++) { - positions_data[ptr] = y + dy; - positions_data[num_patches + ptr] = x + dx; - positions_data[num_patches * 2 + ptr] = y + dy; - positions_data[num_patches * 3 + ptr] = x + dx; + auto remap = idx[ptr / mpow]; + remap = remap * mpow + (ptr % mpow); + + positions_data[remap] = y + dy; + positions_data[num_patches + remap] = x + dx; + positions_data[num_patches * 2 + remap] = y + dy; + positions_data[num_patches * 3 + remap] = x + dx; ptr++; } } } } - ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + if (positions) ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); - } } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { // do nothing diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index a870bc4be3db5..f95677eef9e8a 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -370,14 +370,14 @@ static void debug_test_mrope_2d() { // 1. Initialize backend ggml_backend_t backend = NULL; std::string backend_name = ""; -#ifdef GGML_USE_CUDA - fprintf(stderr, "%s: using CUDA backend\n", __func__); - backend = ggml_backend_cuda_init(0); // init device 0 - backend_name = "cuda"; - if (!backend) { - fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); - } -#endif +// #ifdef GGML_USE_CUDA +// fprintf(stderr, "%s: using CUDA backend\n", __func__); +// backend = ggml_backend_cuda_init(0); // init device 0 +// backend_name = "cuda"; +// if (!backend) { +// fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); +// } +// #endif // if there aren't GPU Backends fallback to CPU backend if (!backend) { backend = ggml_backend_cpu_init(); @@ -490,14 +490,14 @@ static void debug_patch_layout() { // 1. Initialize backend ggml_backend_t backend = NULL; std::string backend_name = ""; -#ifdef GGML_USE_CUDA - fprintf(stderr, "%s: using CUDA backend\n", __func__); - backend = ggml_backend_cuda_init(0); // init device 0 - backend_name = "cuda"; - if (!backend) { - fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); - } -#endif +// #ifdef GGML_USE_CUDA +// fprintf(stderr, "%s: using CUDA backend\n", __func__); +// backend = ggml_backend_cuda_init(0); // init device 0 +// backend_name = "cuda"; +// if (!backend) { +// fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); +// } +// #endif // if there aren't GPU Backends fallback to CPU backend if (!backend) { backend = ggml_backend_cpu_init(); @@ -615,14 +615,14 @@ static void debug_test_get_rows() { // 1. 
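The CPU-side position fill above, as a NumPy sketch; idx comes from the window grouping, and with window attention off the patch falls back to an identity mapping via std::iota (arange here):

import numpy as np

def mrope_positions(ipw, iph, idx, merge_ratio=2):
    # Four position sections per token (the patch writes y, x, y, x), laid out
    # section-major in one flat buffer, with each 2x2 block scattered to the
    # slot the window grouping assigned it.
    num_patches = ipw * iph
    mpow = merge_ratio * merge_ratio
    pos = np.zeros(4 * num_patches, dtype=np.int32)
    ptr = 0
    for y in range(0, iph, merge_ratio):
        for x in range(0, ipw, merge_ratio):
            for dy in range(merge_ratio):
                for dx in range(merge_ratio):
                    remap = idx[ptr // mpow] * mpow + ptr % mpow
                    pos[remap] = y + dy
                    pos[num_patches + remap] = x + dx
                    pos[2 * num_patches + remap] = y + dy
                    pos[3 * num_patches + remap] = x + dx
                    ptr += 1
    return pos

# no window attention: identity remap over the (ipw/2)*(iph/2) merged tokens
identity = np.arange((4 // 2) * (4 // 2), dtype=np.int32)
print(mrope_positions(4, 4, identity).reshape(4, 16))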
Initialize backend ggml_backend_t backend = NULL; std::string backend_name = ""; -#ifdef GGML_USE_CUDA - fprintf(stderr, "%s: using CUDA backend\n", __func__); - backend = ggml_backend_cuda_init(0); // init device 0 - backend_name = "cuda"; - if (!backend) { - fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); - } -#endif +// #ifdef GGML_USE_CUDA +// fprintf(stderr, "%s: using CUDA backend\n", __func__); +// backend = ggml_backend_cuda_init(0); // init device 0 +// backend_name = "cuda"; +// if (!backend) { +// fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); +// } +// #endif // if there aren't GPU Backends fallback to CPU backend if (!backend) { backend = ggml_backend_cpu_init(); From bd518bf1c3097ec05aaf378cd146e91cbbd615ed Mon Sep 17 00:00:00 2001 From: HimariO Date: Tue, 1 Apr 2025 21:07:36 +0800 Subject: [PATCH 07/12] cleaning up --- examples/llava/clip.cpp | 56 ++---------------------------- examples/llava/qwen2_vl_surgery.py | 2 +- examples/llava/qwen2vl-cli.cpp | 10 +++--- 3 files changed, 8 insertions(+), 60 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index cdded38d37b77..5db2f6afc9345 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -594,10 +594,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); inp = ggml_add(ctx0, inp, inp_1); - // ggml_build_forward_expand(gf, inp); - // ggml_free(ctx0); - // return gf; - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] inp = ggml_reshape_4d( ctx0, inp, @@ -609,10 +605,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im inp = ggml_reshape_3d( ctx0, inp, hidden_size, patches_w * patches_h, batch_size); - - // ggml_build_forward_expand(gf, inp); - // ggml_free(ctx0); - // return gf; } else { inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); @@ -701,18 +693,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4); embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx); embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size); - - // positions = ggml_reshape_2d(ctx0, positions, num_position_ids / 4, 4); - // positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3)); - // positions = ggml_reshape_2d(ctx0, positions, 16, num_position_ids / 16); - // positions = ggml_get_rows(ctx0, positions, inv_window_idx); - // positions = ggml_reshape_2d(ctx0, positions, 4, num_position_ids / 4); - // positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3)); - // positions = ggml_reshape_1d(ctx0, positions, num_position_ids); - - // ggml_build_forward_expand(gf, embeddings); - // ggml_free(ctx0); - // return gf; } for (int il = 0; il < ctx->max_feature_layer; il++) { @@ -736,12 +716,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b); } - // if ( il == 0) { - // // build the graph - // ggml_build_forward_expand(gf, cur); - // ggml_free(ctx0); - // return gf; - // } // self-attention { @@ -784,17 +758,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im KQ = ggml_soft_max_ext(ctx0, KQ, 
nullptr, 1.0f / sqrtf((float)d_head), 0.0f); } else { KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f, 0.0f); - // KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrt((float)d_head)); // KQ = ggml_add(ctx0, KQ, window_mask); // KQ = ggml_soft_max_inplace(ctx0, KQ); } - // if ( il == 0) { - // // build the graph - // ggml_build_forward_expand(gf, KQ); - // ggml_free(ctx0); - // return gf; - // } struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); @@ -810,12 +777,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im cur = ggml_add(ctx0, cur, embeddings); embeddings = cur; // embeddings = residual, cur = hidden_states - // if ( il == 0) { - // // build the graph - // ggml_build_forward_expand(gf, cur); - // ggml_free(ctx0); - // return gf; - // } // layernorm2 if (ctx->use_rms_norm) { @@ -867,19 +828,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im cur = ggml_add(ctx0, embeddings, cur); embeddings = cur; - - // if ( il == 0) { - // // build the graph - // ggml_build_forward_expand(gf, embeddings); - // ggml_free(ctx0); - // return gf; - // } } - // ggml_build_forward_expand(gf, embeddings); - // ggml_free(ctx0); - // return gf; - // post-layernorm if (model.post_ln_w) { if (ctx->use_rms_norm) { @@ -2777,9 +2727,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); - - // const int pw = image_size_width / patch_size; - // const int ph = image_size_height / patch_size; const int mpow = (merge_ratio * merge_ratio); int* positions_data = (int*)malloc(ggml_nbytes(positions)); @@ -2792,6 +2739,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima for (int dx = 0; dx < 2; dx++) { auto remap = idx[ptr / mpow]; remap = remap * mpow + (ptr % mpow); + // auto remap = ptr; positions_data[remap] = y + dy; positions_data[num_patches + remap] = x + dx; @@ -2803,7 +2751,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } - if (positions) ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py index 8f7a94e5c3797..9d4ad8932c07f 100644 --- a/examples/llava/qwen2_vl_surgery.py +++ b/examples/llava/qwen2_vl_surgery.py @@ -102,7 +102,7 @@ def main(args): np_dtype = np.float32 ftype = 0 elif args.data_type == 'fp16': - dtype = torch.float32 + dtype = torch.float16 np_dtype = np.float16 ftype = 1 else: diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index f95677eef9e8a..4598fab25f79b 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -771,10 +771,10 @@ enum model_output_type { }; static void debug_dump_img_embed(struct llava_context * ctx_llava, model_output_type output_type) { - int ih = 140; - int iw = 196; - // int ih = 56; - // int iw = 56; + constexpr int ih = 140; + constexpr int iw = 196; + // constexpr int ih = 56; + // constexpr int iw = 56; // int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama)); int n_embd = 1280; int merge = 1; @@ -954,7 +954,7 @@ int main(int argc, char ** argv) { // debug_test_mrope_2d(); debug_dump_img_embed(ctx_llava, 
model_output_type::final_layer); - // debug_dump_img_embed(ctx_llava, model_output_type::conv3d); + // debug_dump_img_embed(ctx_llava, model_output_type::last_attn_layer); // debug_test_get_rows(); // dump_win_attn_mask(); // debug_patch_layout(); From edd35fe55ddebafc3df3485419d46b2189ac4fd0 Mon Sep 17 00:00:00 2001 From: HimariO Date: Thu, 3 Apr 2025 22:52:44 +0800 Subject: [PATCH 08/12] ignore transformers Qwen2_5_xxx type check --- examples/llava/qwen2_vl_surgery.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py index 9d4ad8932c07f..0a47a719fa544 100644 --- a/examples/llava/qwen2_vl_surgery.py +++ b/examples/llava/qwen2_vl_surgery.py @@ -6,11 +6,11 @@ from gguf import * from transformers import ( AutoProcessor, - Qwen2VLForConditionalGeneration, - Qwen2_5_VLForConditionalGeneration, - Qwen2VLProcessor, Qwen2VLConfig, - Qwen2_5_VLConfig, + Qwen2VLProcessor, + Qwen2VLForConditionalGeneration, + Qwen2_5_VLConfig, # type: ignore[reportAttributeAccessIssue] + Qwen2_5_VLForConditionalGeneration, # type: ignore[reportAttributeAccessIssue] ) From bb6fdc8e87cbc78e9236b99565ace51f035bb9cd Mon Sep 17 00:00:00 2001 From: HimariO Date: Fri, 4 Apr 2025 15:18:01 +0800 Subject: [PATCH 09/12] reuse qwen2vl converter instead --- convert_hf_to_gguf.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index be97c8da10691..2bf97475f78dd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2375,11 +2375,6 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: yield name, data -@Model.register("Qwen2_5_VLForConditionalGeneration") -class Qwen25VLModel(Qwen2VLModel): - model_arch = gguf.MODEL_ARCH.QWEN2VL - - @Model.register("WavTokenizerDec") class WavTokenizerDecModel(Model): model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC From 1acb5e3621f07fb1da2aedcbd923494ed8833fc4 Mon Sep 17 00:00:00 2001 From: HimariO Date: Fri, 4 Apr 2025 15:21:04 +0800 Subject: [PATCH 10/12] remove not so often use `qwen2vl-cli` debug functions --- examples/llava/qwen2vl-cli.cpp | 277 --------------------------------- 1 file changed, 277 deletions(-) diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 4598fab25f79b..810bfa37f37c7 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -486,280 +486,6 @@ static void debug_test_mrope_2d() { ggml_backend_free(backend); } -static void debug_patch_layout() { - // 1. Initialize backend - ggml_backend_t backend = NULL; - std::string backend_name = ""; -// #ifdef GGML_USE_CUDA -// fprintf(stderr, "%s: using CUDA backend\n", __func__); -// backend = ggml_backend_cuda_init(0); // init device 0 -// backend_name = "cuda"; -// if (!backend) { -// fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); -// } -// #endif - // if there aren't GPU Backends fallback to CPU backend - if (!backend) { - backend = ggml_backend_cpu_init(); - backend_name = "cpu"; - } - - // Calculate the size needed to allocate - size_t ctx_size = 0; - ctx_size += 2 * ggml_tensor_overhead(); // tensors - // no need to allocate anything else! - - // 2. 
Allocate `ggml_context` to store tensor data - struct ggml_init_params params = { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors() - }; - struct ggml_context * ctx = ggml_init(params); - - const int patches_w = 14; - const int patches_h = 10; - const int c = 2; - const int batch_size = 1; - struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, patches_w, patches_h, c, batch_size); - ggml_set_name(inp_raw, "inp_raw"); - ggml_set_input(inp_raw); - - - std::vector dummy_q; - dummy_q.resize(patches_w * patches_h * c * batch_size); - for (size_t i = 0; i < patches_h * patches_w * c; i++) - { - dummy_q[i] = i; - } - - // std::fill(dummy_q.begin(), dummy_q.end(), 0.1); - // memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw)); - - // 4. Allocate a `ggml_backend_buffer` to store all tensors - ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); - - // 5. Copy tensor data from main memory (RAM) to backend buffer - ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw)); - - // 6. Create a `ggml_cgraph` for mul_mat operation - struct ggml_cgraph * gf = NULL; - struct ggml_context * ctx0 = NULL; - - // create a temporally context to build the graph - struct ggml_init_params params0 = { - /*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - ctx0 = ggml_init(params0); - gf = ggml_new_graph(ctx0); - /* - Compute graph - */ - struct ggml_tensor * inp = ggml_cont(ctx0, ggml_permute(ctx0, inp_raw, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] - - inp = ggml_reshape_4d( - ctx0, inp, - c * 2, patches_w / 2, patches_h, batch_size); - inp = ggml_reshape_4d( - ctx0, inp, - c * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); - inp = ggml_reshape_3d( - ctx0, inp, - c, patches_w * patches_h, batch_size); - - // Add "result" tensor and all of its dependencies to the cgraph - ggml_build_forward_expand(gf, inp); - - // 7. Create a `ggml_gallocr` for cgraph computation - ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); - ggml_gallocr_alloc_graph(allocr, gf); - - // 9. Run the computation - int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading - if (ggml_backend_is_cpu(backend)) { - ggml_backend_cpu_set_n_threads(backend, n_threads); - } - ggml_backend_graph_compute(backend, gf); - - // 10. Retrieve results (output tensors) - // in this example, output tensor is always the last tensor in the graph - struct ggml_tensor * result = inp; - // struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1]; - float * result_data = (float *)malloc(ggml_nbytes(result)); - // because the tensor data is stored in device buffer, we need to copy it back to RAM - ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result)); - const std::string bin_file = "patch_layout_" + backend_name +".bin"; - std::ofstream outFile(bin_file, std::ios::binary); - - if (outFile.is_open()) { - outFile.write(reinterpret_cast(result_data), ggml_nbytes(result)); - outFile.close(); - std::cout << "Data successfully written to " + bin_file << std::endl; - } else { - std::cerr << "Error opening file!" << std::endl; - } - - free(result_data); - // 11. 
Free memory and exit - ggml_free(ctx0); - ggml_gallocr_free(allocr); - ggml_free(ctx); - ggml_backend_buffer_free(buffer); - ggml_backend_free(backend); -} - -static void debug_test_get_rows() { - // 1. Initialize backend - ggml_backend_t backend = NULL; - std::string backend_name = ""; -// #ifdef GGML_USE_CUDA -// fprintf(stderr, "%s: using CUDA backend\n", __func__); -// backend = ggml_backend_cuda_init(0); // init device 0 -// backend_name = "cuda"; -// if (!backend) { -// fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); -// } -// #endif - // if there aren't GPU Backends fallback to CPU backend - if (!backend) { - backend = ggml_backend_cpu_init(); - backend_name = "cpu"; - } - - // Calculate the size needed to allocate - size_t ctx_size = 0; - ctx_size += 128 * ggml_tensor_overhead(); // tensors - // no need to allocate anything else! - - // 2. Allocate `ggml_context` to store tensor data - struct ggml_init_params params = { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors() - }; - struct ggml_context * ctx = ggml_init(params); - - const int tokens = 30; - struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 3, tokens * 2); - ggml_set_name(inp_raw, "inp_raw"); - ggml_set_input(inp_raw); - - struct ggml_tensor * pos = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 4, tokens); - // struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, tokens * 4); - ggml_set_name(pos, "pos"); - ggml_set_input(pos); - - struct ggml_tensor * ind = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, tokens); - ggml_set_name(ind, "ind"); - ggml_set_input(ind); - - struct ggml_tensor * ind_2d = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 1, tokens); - ggml_set_name(ind_2d, "ind_2d"); - ggml_set_input(ind_2d); - - std::vector dummy_q; - dummy_q.resize(128 * 3 * inp_raw->ne[2]); - for (int i = 0; i < inp_raw->ne[2]; i ++) { - for (int j = 0; j < 3; j ++) { - int offset = i * 128 * 3 + j * 128; - std::fill(dummy_q.begin() + offset, dummy_q.begin() + offset + 128, 0.1 * i); - } - } - // std::fill(dummy_q.begin(), dummy_q.end(), 0.1); - // memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw)); - - std::vector pos_id; - pos_id.resize(tokens * 4); - for (int i = 0; i < tokens; i ++) { - pos_id[i] = i; - pos_id[i + tokens * 1] = i + 10; - pos_id[i + tokens * 2] = i + 20; - pos_id[i + tokens * 3] = i + 30; - } - - std::vector remap_ind; - remap_ind.resize(tokens * 4); - for (int i = 0; i < tokens; i ++) { - remap_ind[i] = tokens - i - 1; - } - - // 4. Allocate a `ggml_backend_buffer` to store all tensors - ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); - - // 5. Copy tensor data from main memory (RAM) to backend buffer - ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw)); - ggml_backend_tensor_set(pos, pos_id.data(), 0, ggml_nbytes(pos)); - ggml_backend_tensor_set(ind, remap_ind.data(), 0, ggml_nbytes(ind)); - ggml_backend_tensor_set(ind_2d, remap_ind.data(), 0, ggml_nbytes(ind_2d)); - - // 6. 
Create a `ggml_cgraph` for mul_mat operation - struct ggml_cgraph * gf = NULL; - struct ggml_context * ctx_cgraph = NULL; - - // create a temporally context to build the graph - struct ggml_init_params params0 = { - /*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - ctx_cgraph = ggml_init(params0); - gf = ggml_new_graph(ctx_cgraph); - - // ne = [128, 1, 30, 1] - auto x = ggml_reshape_2d(ctx_cgraph, inp_raw, 128 * 3 * 2, tokens); - struct ggml_tensor * result0 = ggml_get_rows( - ctx_cgraph, x, ind); - result0 = ggml_reshape_3d(ctx_cgraph, result0, 128, 3, tokens * 2); - - struct ggml_tensor * result1 = ggml_get_rows( - ctx_cgraph, pos, ind); - - // Add "result" tensor and all of its dependencies to the cgraph - ggml_build_forward_expand(gf, result0); - ggml_build_forward_expand(gf, result1); - - // 7. Create a `ggml_gallocr` for cgraph computation - ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); - ggml_gallocr_alloc_graph(allocr, gf); - - // 9. Run the computation - int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading - if (ggml_backend_is_cpu(backend)) { - ggml_backend_cpu_set_n_threads(backend, n_threads); - } - ggml_backend_graph_compute(backend, gf); - - // 10. Retrieve results (output tensors) - // in this example, output tensor is always the last tensor in the graph - struct ggml_tensor * result = result0; - // struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1]; - float * result_data = (float *)malloc(ggml_nbytes(result)); - // because the tensor data is stored in device buffer, we need to copy it back to RAM - ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result)); - const std::string bin_file = "getrows_" + backend_name +"_0.bin"; - std::ofstream outFile(bin_file, std::ios::binary); - - if (outFile.is_open()) { - outFile.write(reinterpret_cast(result_data), ggml_nbytes(result)); - outFile.close(); - std::cout << "Data successfully written to " + bin_file << std::endl; - } else { - std::cerr << "Error opening file!" << std::endl; - } - - free(result_data); - // 11. 
Free memory and exit - ggml_free(ctx_cgraph); - ggml_gallocr_free(allocr); - ggml_free(ctx); - ggml_backend_buffer_free(buffer); - ggml_backend_free(backend); -} - - enum model_output_type { conv3d, patch_embed, @@ -955,9 +681,6 @@ int main(int argc, char ** argv) { // debug_test_mrope_2d(); debug_dump_img_embed(ctx_llava, model_output_type::final_layer); // debug_dump_img_embed(ctx_llava, model_output_type::last_attn_layer); - // debug_test_get_rows(); - // dump_win_attn_mask(); - // debug_patch_layout(); llama_perf_context_print(ctx_llava->ctx_llama); ctx_llava->model = NULL; From ef5a75fadc7801f5c9da14122e15a3a77a20f3a3 Mon Sep 17 00:00:00 2001 From: HimariO Date: Fri, 4 Apr 2025 15:44:31 +0800 Subject: [PATCH 11/12] remove commented-out code blocks --- examples/llava/clip.cpp | 7 +--- examples/llava/qwen2vl-cli.cpp | 70 ---------------------------------- 2 files changed, 1 insertion(+), 76 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 5db2f6afc9345..feeb351399fc7 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -758,9 +758,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); } else { KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f, 0.0f); - // KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrt((float)d_head)); - // KQ = ggml_add(ctx0, KQ, window_mask); - // KQ = ggml_soft_max_inplace(ctx0, KQ); } struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); @@ -2739,9 +2736,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima for (int dx = 0; dx < 2; dx++) { auto remap = idx[ptr / mpow]; remap = remap * mpow + (ptr % mpow); - // auto remap = ptr; - positions_data[remap] = y + dy; + positions_data[remap] = y + dy; positions_data[num_patches + remap] = x + dx; positions_data[num_patches * 2 + remap] = y + dy; positions_data[num_patches * 3 + remap] = x + dx; @@ -2836,7 +2832,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } - if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 810bfa37f37c7..cf42710869191 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -567,76 +567,6 @@ static void debug_dump_img_embed(struct llava_context * ctx_llava, model_output_ } } - -static void dump_win_attn_mask() { - const int image_size_width = 196; - const int image_size_height = 140; - const int patch_size = 14; - const int attn_window_size = 112; - - const int merge_ratio = 2; - const int ipw = image_size_width / patch_size; - const int iph = image_size_height / patch_size; - const int pw = image_size_width / patch_size / merge_ratio; - const int ph = image_size_height / patch_size / merge_ratio; - const int grid_window = attn_window_size / patch_size / merge_ratio; - /* - pw * ph = number of tokens output by ViT after apply patch merger - ipw * ipw = number of vision token been processed inside ViT - */ - - std::vector idx(ph * pw); - std::vector inv_idx(ph * pw); - int dst = 0; - // [num_vision_tokens, num_vision_tokens] attention mask tensor - int ne = pow(ipw * iph, 2); - std::vector mask(ne, 
std::numeric_limits::lowest()); - int mask_row = 0; - - for (int y = 0; y < ph; y+=grid_window) - { - for (int x = 0; x < pw; x+=grid_window) - { - const int win_h = std::min(grid_window, ph - y); - const int win_w = std::min(grid_window, pw - x); - const int dst_0 = dst; - // group all tokens belong to the same window togather (to a continue range) - for (int dy = 0; dy < win_h; dy++) { - for (int dx = 0; dx < win_w; dx++) { - const int src = (y + dy) * pw + (x + dx); - assert(src < (int)idx.size()); - assert(dst < (int)inv_idx.size()); - idx[src] = dst; - inv_idx[dst] = src; - dst++; - } - } - - for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { - int row_offset = mask_row * (ipw * iph); - std::fill( - mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), - mask.begin() + row_offset + (dst * merge_ratio * merge_ratio), - 0.0); - mask_row++; - } - } - } - - auto output_path = "win_attn_mask_fp32.bin"; - - std::ofstream outFile(output_path, std::ios::binary); - if (outFile.is_open()) { - outFile.write(reinterpret_cast(mask.data()), ne * sizeof(float)); - - outFile.close(); - std::cout << "Data successfully written to " << output_path << std::endl; - } else { - std::cerr << "Error opening file!" << std::endl; - } -} - - #endif From b27a8dcb0ed354b4d025d4a24a3921edceb1e5a7 Mon Sep 17 00:00:00 2001 From: HimariO Date: Mon, 7 Apr 2025 22:07:56 +0800 Subject: [PATCH 12/12] fix attn weight scaling after rebase --- examples/llava/clip.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index feeb351399fc7..d130c4766dfd2 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -757,7 +757,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im if (full_attn) { KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); } else { - KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f, 0.0f); + KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f); } struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
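
Reviewer note (not part of the patches): the window-attention bookkeeping added in clip_image_batch_encode() is the densest part of this series, so here is a minimal NumPy sketch of the same index/mask construction for reference. The image size (196x140), patch size (14) and window size (112) are illustrative only — they match the values used in the removed debug helpers — and the function name build_window_layout is made up for this sketch.

import numpy as np

def build_window_layout(image_w=196, image_h=140, patch_size=14,
                        attn_window_size=112, merge_ratio=2):
    # pw * ph   : tokens left after the 2x2 patch merger
    # ipw * iph : tokens processed inside the ViT
    pw  = image_w // patch_size // merge_ratio
    ph  = image_h // patch_size // merge_ratio
    ipw = image_w // patch_size
    iph = image_h // patch_size
    grid_window = attn_window_size // patch_size // merge_ratio

    idx     = np.zeros(ph * pw, dtype=np.int32)  # row-major src -> window-grouped dst
    inv_idx = np.zeros(ph * pw, dtype=np.int32)  # window-grouped dst -> row-major src
    mask    = np.full((ipw * iph, ipw * iph),
                      np.finfo(np.float32).min, dtype=np.float32)

    m2  = merge_ratio * merge_ratio
    dst = 0
    mask_row = 0
    for y in range(0, ph, grid_window):
        for x in range(0, pw, grid_window):
            win_h = min(grid_window, ph - y)
            win_w = min(grid_window, pw - x)
            dst_0 = dst
            # group all tokens of one window into a contiguous dst range
            for dy in range(win_h):
                for dx in range(win_w):
                    src = (y + dy) * pw + (x + dx)
                    idx[src] = dst
                    inv_idx[dst] = src
                    dst += 1
            # un-mask (set to 0) the rows/columns of this window; every merged
            # token corresponds to merge_ratio**2 ViT tokens
            for _ in range(win_h * win_w * m2):
                mask[mask_row, dst_0 * m2:dst * m2] = 0.0
                mask_row += 1
    return idx, inv_idx, mask

idx, inv_idx, mask = build_window_layout()
print(idx.shape, inv_idx.shape, mask.shape)  # (35,) (35,) (140, 140)

The same idx table is what scatters the M-RoPE position ids so they follow the window-grouped token order, and window_mask is only applied in the blocks not listed in clip.vision.fullatt_block_indexes; those blocks pass a nullptr mask to ggml_soft_max_ext, as in the final hunk above.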