Commit f027196

graph : continue to explicitly cast K and V to F16
1 parent 34f95c3 commit f027196


src/llama-graph.cpp

Lines changed: 9 additions & 0 deletions
@@ -1352,6 +1352,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
         }
 
+        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
+        if (k->type == GGML_TYPE_F32) {
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        }
+
+        if (v->type == GGML_TYPE_F32) {
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        }
+
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
         cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
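
The comment in the hunk explains when the F32 case arises: with the KV cache in use, K and V are read back from the cache already stored as F16 (or a quantized type), but when the cache is bypassed, e.g. an embedding model with non-causal attention, both tensors arrive straight from the graph in F32, presumably a type the flash-attention path does not accept, so they are cast first. A minimal standalone sketch of the same guard, assuming ggml.h and a valid ggml_context; the helper name ensure_f16 is illustrative and not part of the commit:

    #include "ggml.h"

    // Illustrative helper mirroring the commit's guard: insert a cast node
    // into the graph only when the tensor is still F32; tensors of any
    // other type pass through unchanged.
    static struct ggml_tensor * ensure_f16(struct ggml_context * ctx0, struct ggml_tensor * t) {
        if (t->type == GGML_TYPE_F32) {
            t = ggml_cast(ctx0, t, GGML_TYPE_F16);
        }
        return t;
    }

With such a helper, the two guards at the call site would reduce to k = ensure_f16(ctx0, k); and v = ensure_f16(ctx0, v);.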
