Skip to content

Commit ea677dc

Browse files
committed
fix vit pos embed, deepstack and mrope-interleaved!
love from qwen team!
1 parent ab45b1a commit ea677dc

File tree

6 files changed

+347
-40
lines changed

6 files changed

+347
-40
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5487,6 +5487,12 @@ static void ggml_mrope_cache_init(
54875487
int sec_e = sections[2] + sec_w;
54885488
GGML_ASSERT(sect_dims <= ne0);
54895489

5490+
// Qwen3VL: interleaved mrope, currently judged by the number of sections
5491+
bool is_interleaved_mrope = false;
5492+
if (sections[0] == 24 && sections[1] == 20 && sections[2] == 20) {
5493+
is_interleaved_mrope = true;
5494+
}
5495+
54905496
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
54915497
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
54925498

@@ -5510,14 +5516,25 @@ static void ggml_mrope_cache_init(
55105516

55115517
float theta = theta_t;
55125518

5513-
if (sector >= sections[0] && sector < sec_w) {
5514-
theta = theta_h;
5515-
}
5516-
else if (sector >= sec_w && sector < sec_w + sections[2]) {
5517-
theta = theta_w;
5518-
}
5519-
else if (sector >= sec_w + sections[2]) {
5520-
theta = theta_e;
5519+
if (is_interleaved_mrope) {
5520+
// thwthwthw...ttt
5521+
if (sector % 3 == 1 && sector < 3 * sections[1]) {
5522+
theta = theta_h;
5523+
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
5524+
theta = theta_w;
5525+
} else {
5526+
theta = theta_e;
5527+
}
5528+
} else {
5529+
if (sector >= sections[0] && sector < sec_w) {
5530+
theta = theta_h;
5531+
}
5532+
else if (sector >= sec_w && sector < sec_w + sections[2]) {
5533+
theta = theta_w;
5534+
}
5535+
else if (sector >= sec_w + sections[2]) {
5536+
theta = theta_e;
5537+
}
55215538
}
55225539

55235540
rope_yarn(

ggml/src/ggml-cuda/rope.cu

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -151,18 +151,30 @@ static __global__ void rope_multi(
151151
const int sec_w = sections.v[1] + sections.v[0];
152152
const int sector = (i0 / 2) % sect_dims;
153153

154+
bool is_interleaved_mrope = (sections.v[0] == 24 && sections.v[1] == 20 && sections.v[2] == 20);
155+
154156
float theta_base = 0.0;
155-
if (sector < sections.v[0]) {
156-
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
157-
}
158-
else if (sector >= sections.v[0] && sector < sec_w) {
159-
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
160-
}
161-
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
162-
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
163-
}
164-
else if (sector >= sec_w + sections.v[2]) {
165-
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
157+
if (is_interleaved_mrope) {
158+
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
159+
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
160+
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
161+
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
162+
} else { // t
163+
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
164+
}
165+
} else {
166+
if (sector < sections.v[0]) {
167+
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
168+
}
169+
else if (sector >= sections.v[0] && sector < sec_w) {
170+
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
171+
}
172+
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
173+
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
174+
}
175+
else if (sector >= sec_w + sections.v[2]) {
176+
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
177+
}
166178
}
167179

168180
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

src/llama-graph.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,6 +1141,55 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
11411141
return cur;
11421142
}
11431143

1144+
// input embeddings with optional lora for qwen3vl series model
1145+
ggml_tensor * llm_graph_context::build_qwen3vl_inp_embd(ggml_tensor * tok_embd) const {
1146+
const int64_t n_embd_full = hparams.n_embd; // main + 3 deepstack layers
1147+
const int64_t n_embd_main = n_embd_full / 4;
1148+
1149+
auto inp = std::make_unique<llm_graph_input_embd>();
1150+
ggml_tensor * cur = nullptr;
1151+
1152+
if (ubatch.token) {
1153+
// Pure text input: expand to 4*n_embd with zero deepstack
1154+
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
1155+
ggml_set_input(inp->tokens);
1156+
res->t_tokens = inp->tokens;
1157+
1158+
// Get main embedding from token IDs
1159+
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
1160+
1161+
// Apply LoRA if needed
1162+
for (const auto & lora : *loras) {
1163+
llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
1164+
if (lw == nullptr) continue;
1165+
1166+
const float adapter_scale = lora.second;
1167+
const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
1168+
ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
1169+
ctx0, lw->b,
1170+
ggml_get_rows(ctx0, lw->a, inp->tokens)
1171+
), scale);
1172+
cur = ggml_add(ctx0, cur, inpL_delta);
1173+
}
1174+
} else {
1175+
// Custom embedding input (e.g., from image): assume already 4*n_embd
1176+
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_full, ubatch.n_tokens);
1177+
ggml_set_input(inp->embd);
1178+
cur = inp->embd;
1179+
}
1180+
1181+
// Apply embedding scale if needed (e.g., Granite)
1182+
if (hparams.f_embedding_scale != 0.0f) {
1183+
cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
1184+
}
1185+
1186+
// Register to graph and input system
1187+
cb(cur, "inp_embd_qwen3vl", -1);
1188+
res->add_input(std::move(inp));
1189+
1190+
return cur;
1191+
}
1192+
11441193
ggml_tensor * llm_graph_context::build_inp_pos() const {
11451194
auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
11461195

src/llama-graph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,8 @@ struct llm_graph_context {
684684
ggml_tensor * build_inp_pos_bucket_dec() const;
685685
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
686686

687+
ggml_tensor * build_qwen3vl_inp_embd(ggml_tensor * tok_embd) const;
688+
687689
//
688690
// attention
689691
//

0 commit comments

Comments (0)